-
-
Save bluss/5a4462a1effb5c77f4d55ce045083171 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Cargo.toml | 2 +- | |
| src/eigenvalues/general.rs | 10 +++++----- | |
| src/eigenvalues/symmetric.rs | 8 ++++---- | |
| src/eigenvalues/types.rs | 7 ++++--- | |
| src/impl_prelude.rs | 2 +- | |
| src/least_squares.rs | 22 +++++++++++----------- | |
| src/solve_linear/general.rs | 16 ++++++++-------- | |
| src/solve_linear/symmetric.rs | 16 ++++++++-------- | |
| src/svd/general.rs | 4 ++-- | |
| src/svd/types.rs | 7 ++++--- | |
| src/util/internal.rs | 10 +++++++--- | |
| 11 files changed, 55 insertions(+), 49 deletions(-) | |
| diff --git a/Cargo.toml b/Cargo.toml | |
| index cf6b4f3..d174826 100644 | |
| --- a/Cargo.toml | |
| +++ b/Cargo.toml | |
| @@ -19,7 +19,7 @@ openblas = ["blas/openblas", "lapack/openblas"] | |
| netlib = ["blas/netlib", "lapack/netlib"] | |
| [dependencies] | |
| -ndarray = "0.6" | |
| +ndarray = { version = "0.6", path = ".." } | |
| num-traits = "0.1" | |
| rand = "0.3" | |
| diff --git a/src/eigenvalues/general.rs b/src/eigenvalues/general.rs | |
| index 980e00c..935f5e1 100644 | |
| --- a/src/eigenvalues/general.rs | |
| +++ b/src/eigenvalues/general.rs | |
| @@ -45,18 +45,18 @@ macro_rules! impl_eigen_real { | |
| where D: DataMut<Elem=Self> + DataOwned<Elem=Self> { | |
| let dim = mat.dim(); | |
| - if dim.0 != dim.1 { | |
| + if !dim.is_square() { | |
| return Err(EigenError::NotSquare); | |
| } | |
| - let n = mat.dim().0 as i32; | |
| + let n = dim[0] as i32; | |
| let (data_slice, layout, ld) = match slice_and_layout_mut(&mut mat) { | |
| Some(s) => s, | |
| None => return Err(EigenError::BadLayout) | |
| }; | |
| - let mut vl = matrix_with_layout(if compute_left { dim } else { (0, 0) }, layout); | |
| - let mut vr = matrix_with_layout(if compute_right { dim } else { (0, 0) }, layout); | |
| + let mut vl = matrix_with_layout(if compute_left { dim } else { Ix2::default() }, layout); | |
| + let mut vr = matrix_with_layout(if compute_right { dim } else { Ix2::default() }, layout); | |
| let mut values_real_imag = vec![0.0; 2 * n as usize]; | |
| let (mut values_real, mut values_imag) = values_real_imag.split_at_mut(n as usize); | |
| @@ -101,7 +101,7 @@ macro_rules! impl_eigen_complex { | |
| -> Result<Self::Solution, EigenError> | |
| where D: DataMut<Elem=Self> + DataOwned<Elem=Self> { | |
| - let dim = mat.dim(); | |
| + let dim = mat.dim_pattern(); | |
| if dim.0 != dim.1 { | |
| return Err(EigenError::NotSquare); | |
| } | |
| diff --git a/src/eigenvalues/symmetric.rs b/src/eigenvalues/symmetric.rs | |
| index c54928c..25a0c03 100644 | |
| --- a/src/eigenvalues/symmetric.rs | |
| +++ b/src/eigenvalues/symmetric.rs | |
| @@ -15,13 +15,13 @@ pub trait SymEigen: Sized { | |
| fn compute_mut<D>(mat: &mut ArrayBase<D, Ix2>, | |
| uplo: Symmetric, | |
| with_vectors: bool) | |
| - -> Result<Array<Self::SingularValue, Ix>, EigenError> | |
| + -> Result<Array<Self::SingularValue, Ix1>, EigenError> | |
| where D: DataMut<Elem = Self>; | |
| /// Return the real eigenvalues of a symmetric matrix. | |
| fn compute_into<D>(mut mat: ArrayBase<D, Ix2>, | |
| uplo: Symmetric) | |
| - -> Result<Array<Self::SingularValue, Ix>, EigenError> | |
| + -> Result<Array<Self::SingularValue, Ix1>, EigenError> | |
| where D: DataMut<Elem = Self> + DataOwned<Elem = Self> { | |
| Self::compute_mut(&mut mat, uplo, false) | |
| } | |
| @@ -46,10 +46,10 @@ macro_rules! impl_sym_eigen { | |
| type Solution = Solution<Self, Self::SingularValue>; | |
| fn compute_mut<D>(mat: &mut ArrayBase<D, Ix2>, uplo: Symmetric, with_vectors: bool) -> | |
| - Result<Array<Self::SingularValue, Ix>, EigenError> | |
| + Result<Array<Self::SingularValue, Ix1>, EigenError> | |
| where D: DataMut<Elem=Self> | |
| { | |
| - let dim = mat.dim(); | |
| + let dim = mat.dim_pattern(); | |
| if dim.0 != dim.1 { | |
| return Err(EigenError::NotSquare); | |
| } | |
| diff --git a/src/eigenvalues/types.rs b/src/eigenvalues/types.rs | |
| index d26d828..b95c166 100644 | |
| --- a/src/eigenvalues/types.rs | |
| +++ b/src/eigenvalues/types.rs | |
| @@ -1,4 +1,5 @@ | |
| use ndarray::prelude::*; | |
| +use ndarray::{Ix1, Ix2}; | |
| /// Errors from an eigenvalue problem. | |
| #[derive(Debug)] | |
| @@ -15,7 +16,7 @@ pub enum EigenError { | |
| /// eigenvectors of the solution. For symmetric problems, the | |
| /// eigenvectors are placed in `right_eigenvectors`. | |
| pub struct Solution<IV, EV> { | |
| - pub values: Array<EV, Ix>, | |
| - pub left_vectors: Option<Array<IV, (Ix, Ix)>>, | |
| - pub right_vectors: Option<Array<IV, (Ix, Ix)>>, | |
| + pub values: Array<EV, Ix1>, | |
| + pub left_vectors: Option<Array<IV, Ix2>>, | |
| + pub right_vectors: Option<Array<IV, Ix2>>, | |
| } | |
| diff --git a/src/impl_prelude.rs b/src/impl_prelude.rs | |
| index 8d34883..b0d62e4 100644 | |
| --- a/src/impl_prelude.rs | |
| +++ b/src/impl_prelude.rs | |
| @@ -1,5 +1,5 @@ | |
| pub use ndarray::prelude::*; | |
| -pub use ndarray::{Data, DataMut, DataOwned, Ix2}; | |
| +pub use ndarray::{Data, DataMut, DataOwned, Ix1, Ix2}; | |
| pub use util::internal::*; | |
| pub use lapack::{c32, c64}; | |
| pub use types::Symmetric; | |
| diff --git a/src/least_squares.rs b/src/least_squares.rs | |
| index 7e5c995..e8d4928 100644 | |
| --- a/src/least_squares.rs | |
| +++ b/src/least_squares.rs | |
| @@ -142,12 +142,12 @@ pub trait LeastSquares: Sized + Clone { | |
| /// | |
| /// This method will never return `LeastSquaresError::Degenerate`. | |
| fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>, | |
| - b: &ArrayBase<D2, Ix>) | |
| - -> Result<LeastSquaresSolution<Self, Ix>, LeastSquaresError> | |
| + b: &ArrayBase<D2, Ix1>) | |
| + -> Result<LeastSquaresSolution<Self, Ix1>, LeastSquaresError> | |
| where D1: Data<Elem = Self>, | |
| D2: Data<Elem = Self> | |
| { | |
| - let n = b.dim(); | |
| + let n = b.len(); | |
| // Create a new matrix, where the column vector is a degenerate 2-D matrix. | |
| let b_mat = match b.to_owned().into_shape((n, 1)) { | |
| @@ -174,14 +174,14 @@ fn resize_solution<T: Clone + Default, D>(mut b_sol: ArrayBase<D, Ix2>, n: usize | |
| { | |
| let b_dim = b_sol.dim(); | |
| - if b_dim.0 > n { | |
| + if b_dim[0] > n { | |
| // If the matrix is overdetermined, we just need to truncate the | |
| // solution. | |
| b_sol.slice_mut(s![0..n as isize, ..]).to_owned() | |
| } else { | |
| // Otherwise, it's underdetermined, and we need to extend the solution | |
| - let mut extended_sol = Array::default((n, b_dim.1)); | |
| - extended_sol.slice_mut(s![0..b_dim.0 as isize, ..]).assign(&b_sol); | |
| + let mut extended_sol = Array::default((n, b_dim[1])); | |
| + extended_sol.slice_mut(s![0..b_dim[0] as isize, ..]).assign(&b_sol); | |
| extended_sol | |
| } | |
| } | |
| @@ -197,8 +197,8 @@ macro_rules! impl_least_squares { | |
| where D1: DataMut<Elem=Self> + DataOwned<Elem = Self>, | |
| D2: DataMut<Elem=Self> + DataOwned<Elem = Self> { | |
| - let a_dim = a.dim(); | |
| - let b_dim = b.dim(); | |
| + let a_dim = a.dim_pattern(); | |
| + let b_dim = b.dim_pattern(); | |
| // confirm same number of rows. | |
| if a_dim.0 != b_dim.0 { | |
| @@ -242,8 +242,8 @@ macro_rules! impl_least_squares { | |
| where D1: DataMut<Elem=Self> + DataOwned<Elem = Self>, | |
| D2: DataMut<Elem=Self> + DataOwned<Elem = Self> { | |
| - let a_dim = a.dim(); | |
| - let b_dim = b.dim(); | |
| + let a_dim = a.dim_pattern(); | |
| + let b_dim = b.dim_pattern(); | |
| // confirm same number of rows. | |
| if a_dim.0 != b_dim.0 { | |
| @@ -256,7 +256,7 @@ macro_rules! impl_least_squares { | |
| None => return Err(LeastSquaresError::BadLayout) | |
| }; | |
| - let mut svs: Array<$sv_type, Ix> = Array::default(cmp::min(a_dim.0, a_dim.1)); | |
| + let mut svs: Array<$sv_type, Ix1> = Array::default(cmp::min(a_dim.0, a_dim.1)); | |
| let mut rank: i32 = 0; | |
| // compute result | |
| diff --git a/src/solve_linear/general.rs b/src/solve_linear/general.rs | |
| index d301653..506e90a 100644 | |
| --- a/src/solve_linear/general.rs | |
| +++ b/src/solve_linear/general.rs | |
| @@ -16,12 +16,12 @@ pub trait SolveLinear: Sized + Clone { | |
| /// Solve the linear system A * x = b for square matrix `a` and column vector `b`. | |
| fn compute_into<D1, D2>(a: ArrayBase<D1, Ix2>, | |
| - b: ArrayBase<D2, Ix>) | |
| - -> Result<ArrayBase<D2, Ix>, SolveError> | |
| + b: ArrayBase<D2, Ix1>) | |
| + -> Result<ArrayBase<D2, Ix1>, SolveError> | |
| where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>, | |
| D2: DataMut<Elem = Self> + DataOwned<Elem = Self> | |
| { | |
| - let n = b.dim(); | |
| + let n = b.len(); | |
| // Create a new matrix, where the column vector is a degenerate 2-D matrix. | |
| let b_mat = match b.into_shape((n, 1)) { | |
| @@ -51,8 +51,8 @@ pub trait SolveLinear: Sized + Clone { | |
| /// Solve the linear system A * x = b for square matrix `a` and column vector `b`. | |
| fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>, | |
| - b: &ArrayBase<D2, Ix>) | |
| - -> Result<Array<Self, Ix>, SolveError> | |
| + b: &ArrayBase<D2, Ix1>) | |
| + -> Result<Array<Self, Ix1>, SolveError> | |
| where D1: Data<Elem = Self>, | |
| D2: Data<Elem = Self> | |
| { | |
| @@ -72,8 +72,8 @@ macro_rules! impl_solve_linear { | |
| D2: DataMut<Elem=Self> + DataOwned<Elem = Self> { | |
| // Make sure the input is square. | |
| - let dim = a.dim(); | |
| - let b_dim = b.dim(); | |
| + let dim = a.dim_pattern(); | |
| + let b_dim = b.dim_pattern(); | |
| if dim.0 != dim.1 { | |
| return Err(SolveError::NotSquare(dim.0, dim.1)); | |
| @@ -93,7 +93,7 @@ macro_rules! impl_solve_linear { | |
| None => return Err(SolveError::InconsistentLayout) | |
| }; | |
| - let mut perm: Array<i32, Ix> = Array::default(dim.0); | |
| + let mut perm: Array<i32, Ix1> = Array::default(dim.0); | |
| $driver(layout, dim.0 as i32, b_dim.1 as i32, | |
| slice, lda as i32, | |
| diff --git a/src/solve_linear/symmetric.rs b/src/solve_linear/symmetric.rs | |
| index 8370af8..43dc34e 100644 | |
| --- a/src/solve_linear/symmetric.rs | |
| +++ b/src/solve_linear/symmetric.rs | |
| @@ -19,12 +19,12 @@ pub trait SymmetricSolveLinear: Sized + Clone { | |
| /// square matrix `a` and column vector `b`. | |
| fn compute_into<D1, D2>(a: ArrayBase<D1, Ix2>, | |
| uplo: Symmetric, | |
| - b: ArrayBase<D2, Ix>) | |
| - -> Result<ArrayBase<D2, Ix>, SolveError> | |
| + b: ArrayBase<D2, Ix1>) | |
| + -> Result<ArrayBase<D2, Ix1>, SolveError> | |
| where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>, | |
| D2: DataMut<Elem = Self> + DataOwned<Elem = Self> | |
| { | |
| - let n = b.dim(); | |
| + let n = b.len(); | |
| // Create a new matrix, where the column vector is a degenerate 2-D matrix. | |
| let b_mat = match b.into_shape((n, 1)) { | |
| @@ -58,8 +58,8 @@ pub trait SymmetricSolveLinear: Sized + Clone { | |
| /// square matrix `a` and column vector `b`. | |
| fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>, | |
| uplo: Symmetric, | |
| - b: &ArrayBase<D2, Ix>) | |
| - -> Result<Array<Self, Ix>, SolveError> | |
| + b: &ArrayBase<D2, Ix1>) | |
| + -> Result<Array<Self, Ix1>, SolveError> | |
| where D1: Data<Elem = Self>, | |
| D2: Data<Elem = Self> | |
| { | |
| @@ -81,8 +81,8 @@ macro_rules! impl_solve_linear { | |
| D2: DataMut<Elem=Self> + DataOwned<Elem = Self> { | |
| // Make sure the input is square. | |
| - let dim = a.dim(); | |
| - let b_dim = b.dim(); | |
| + let dim = a.dim_pattern(); | |
| + let b_dim = b.dim_pattern(); | |
| if dim.0 != dim.1 { | |
| return Err(SolveError::NotSquare(dim.0, dim.1)); | |
| @@ -102,7 +102,7 @@ macro_rules! impl_solve_linear { | |
| None => return Err(SolveError::InconsistentLayout) | |
| }; | |
| - let mut perm: Array<i32, Ix> = Array::default(dim.0); | |
| + let mut perm: Array<i32, Ix1> = Array::default(dim.0); | |
| $driver(layout, uplo as u8, dim.0 as i32, b_dim.1 as i32, | |
| slice, lda as i32, | |
| diff --git a/src/svd/general.rs b/src/svd/general.rs | |
| index e222bc4..8a786cb 100644 | |
| --- a/src/svd/general.rs | |
| +++ b/src/svd/general.rs | |
| @@ -51,7 +51,7 @@ enum SVDMethod { | |
| /// Choose a method based on the problem. | |
| fn select_svd_method(d: &Ix2, compute_either: bool) -> SVDMethod { | |
| - let mx = cmp::max(d.0, d.1); | |
| + let mx = cmp::max(d[0], d[1]); | |
| // When we're computing one of them singular vector sets, we have | |
| // to compute both with divide and conquer. So, we're bound by the | |
| @@ -79,7 +79,7 @@ macro_rules! impl_svd { | |
| where D: DataMut<Elem=Self> + DataOwned<Elem = Self>{ | |
| let dim = mat.dim(); | |
| - let (m, n) = dim; | |
| + let (m, n) = dim.into_pattern(); | |
| let mut s = Array::default(cmp::min(m, n)); | |
| let (slice, layout, lda) = match slice_and_layout_mut(&mut mat) { | |
| diff --git a/src/svd/types.rs b/src/svd/types.rs | |
| index aa236d7..7c65118 100644 | |
| --- a/src/svd/types.rs | |
| +++ b/src/svd/types.rs | |
| @@ -1,6 +1,7 @@ | |
| use ndarray::prelude::*; | |
| use num_traits::{Float, ToPrimitive}; | |
| use std::fmt::Display; | |
| +use ndarray::{Ix1, Ix2}; | |
| /// Trait for singular values | |
| /// | |
| @@ -21,15 +22,15 @@ pub struct SVDSolution<IV: Sized, SV: Sized> { | |
| /// | |
| /// Singular values, which are guaranteed to be non-negative | |
| /// reals, are returned in descending order. | |
| - pub values: Array<SV, Ix>, | |
| + pub values: Array<SV, Ix1>, | |
| /// The matrix `U` of left singular vectors. | |
| - pub left_vectors: Option<Array<IV, (Ix, Ix)>>, | |
| + pub left_vectors: Option<Array<IV, Ix2>>, | |
| /// The matrix V^t of singular vectors. | |
| /// | |
| /// The transpose of V is stored, not V itself. | |
| - pub right_vectors: Option<Array<IV, (Ix, Ix)>>, | |
| + pub right_vectors: Option<Array<IV, Ix2>>, | |
| } | |
| /// An error resulting from a `SVD::compute*` method. | |
| diff --git a/src/util/internal.rs b/src/util/internal.rs | |
| index dded08f..619ef5e 100644 | |
| --- a/src/util/internal.rs | |
| +++ b/src/util/internal.rs | |
| @@ -1,12 +1,16 @@ | |
| use ndarray::prelude::*; | |
| use ndarray::{Ix2, DataMut}; | |
| +use ndarray::IntoDimension; | |
| use lapack::c::Layout; | |
| use std::slice; | |
| /// Return an array with the specified dimensions and layout. | |
| /// | |
| /// This function is used internally to ensure that | |
| -pub fn matrix_with_layout<T: Default>(d: Ix2, layout: Layout) -> Array<T, Ix2> { | |
| +pub fn matrix_with_layout<P, T: Default>(d: P, layout: Layout) -> Array<T, Ix2> | |
| + where P: IntoDimension<Dim=Ix2>, | |
| +{ | |
| + let d = d.into_dimension(); | |
| Array::default(match layout { | |
| Layout::RowMajor => d.into(), | |
| Layout::ColumnMajor => d.f(), | |
| @@ -29,7 +33,7 @@ pub fn slice_and_layout_mut<D, S: DataMut<Elem = D>>(mat: &mut ArrayBase<S, Ix2> | |
| return None; | |
| } | |
| - let dim = mat.dim(); | |
| + let dim = mat.dim_pattern(); | |
| // One of the stides, must be 1 | |
| if strides.1 == 1 { | |
| @@ -63,7 +67,7 @@ pub fn slice_and_layout_mut<D, S: DataMut<Elem = D>>(mat: &mut ArrayBase<S, Ix2> | |
| pub fn slice_and_layout_matching_mut<D, S: DataMut<Elem = D>>(mat: &mut ArrayBase<S, Ix2>, | |
| layout: Layout) | |
| -> Option<(&mut [S::Elem], Ixs)> { | |
| - let dim = mat.dim(); | |
| + let dim = mat.dim_pattern(); | |
| // For column vectors, we can choose whatever layout we want. | |
| if dim.1 == 1 { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment