Skip to content

Instantly share code, notes, and snippets.

@bluss
Created November 10, 2016 00:24
Show Gist options
  • Select an option

  • Save bluss/5a4462a1effb5c77f4d55ce045083171 to your computer and use it in GitHub Desktop.

Select an option

Save bluss/5a4462a1effb5c77f4d55ce045083171 to your computer and use it in GitHub Desktop.
Cargo.toml | 2 +-
src/eigenvalues/general.rs | 10 +++++-----
src/eigenvalues/symmetric.rs | 8 ++++----
src/eigenvalues/types.rs | 7 ++++---
src/impl_prelude.rs | 2 +-
src/least_squares.rs | 22 +++++++++++-----------
src/solve_linear/general.rs | 16 ++++++++--------
src/solve_linear/symmetric.rs | 16 ++++++++--------
src/svd/general.rs | 4 ++--
src/svd/types.rs | 7 ++++---
src/util/internal.rs | 10 +++++++---
11 files changed, 55 insertions(+), 49 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index cf6b4f3..d174826 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,7 +19,7 @@ openblas = ["blas/openblas", "lapack/openblas"]
netlib = ["blas/netlib", "lapack/netlib"]
[dependencies]
-ndarray = "0.6"
+ndarray = { version = "0.6", path = ".." }
num-traits = "0.1"
rand = "0.3"
diff --git a/src/eigenvalues/general.rs b/src/eigenvalues/general.rs
index 980e00c..935f5e1 100644
--- a/src/eigenvalues/general.rs
+++ b/src/eigenvalues/general.rs
@@ -45,18 +45,18 @@ macro_rules! impl_eigen_real {
where D: DataMut<Elem=Self> + DataOwned<Elem=Self> {
let dim = mat.dim();
- if dim.0 != dim.1 {
+ if !dim.is_square() {
return Err(EigenError::NotSquare);
}
- let n = mat.dim().0 as i32;
+ let n = dim[0] as i32;
let (data_slice, layout, ld) = match slice_and_layout_mut(&mut mat) {
Some(s) => s,
None => return Err(EigenError::BadLayout)
};
- let mut vl = matrix_with_layout(if compute_left { dim } else { (0, 0) }, layout);
- let mut vr = matrix_with_layout(if compute_right { dim } else { (0, 0) }, layout);
+ let mut vl = matrix_with_layout(if compute_left { dim } else { Ix2::default() }, layout);
+ let mut vr = matrix_with_layout(if compute_right { dim } else { Ix2::default() }, layout);
let mut values_real_imag = vec![0.0; 2 * n as usize];
let (mut values_real, mut values_imag) = values_real_imag.split_at_mut(n as usize);
@@ -101,7 +101,7 @@ macro_rules! impl_eigen_complex {
-> Result<Self::Solution, EigenError>
where D: DataMut<Elem=Self> + DataOwned<Elem=Self> {
- let dim = mat.dim();
+ let dim = mat.dim_pattern();
if dim.0 != dim.1 {
return Err(EigenError::NotSquare);
}
diff --git a/src/eigenvalues/symmetric.rs b/src/eigenvalues/symmetric.rs
index c54928c..25a0c03 100644
--- a/src/eigenvalues/symmetric.rs
+++ b/src/eigenvalues/symmetric.rs
@@ -15,13 +15,13 @@ pub trait SymEigen: Sized {
fn compute_mut<D>(mat: &mut ArrayBase<D, Ix2>,
uplo: Symmetric,
with_vectors: bool)
- -> Result<Array<Self::SingularValue, Ix>, EigenError>
+ -> Result<Array<Self::SingularValue, Ix1>, EigenError>
where D: DataMut<Elem = Self>;
/// Return the real eigenvalues of a symmetric matrix.
fn compute_into<D>(mut mat: ArrayBase<D, Ix2>,
uplo: Symmetric)
- -> Result<Array<Self::SingularValue, Ix>, EigenError>
+ -> Result<Array<Self::SingularValue, Ix1>, EigenError>
where D: DataMut<Elem = Self> + DataOwned<Elem = Self> {
Self::compute_mut(&mut mat, uplo, false)
}
@@ -46,10 +46,10 @@ macro_rules! impl_sym_eigen {
type Solution = Solution<Self, Self::SingularValue>;
fn compute_mut<D>(mat: &mut ArrayBase<D, Ix2>, uplo: Symmetric, with_vectors: bool) ->
- Result<Array<Self::SingularValue, Ix>, EigenError>
+ Result<Array<Self::SingularValue, Ix1>, EigenError>
where D: DataMut<Elem=Self>
{
- let dim = mat.dim();
+ let dim = mat.dim_pattern();
if dim.0 != dim.1 {
return Err(EigenError::NotSquare);
}
diff --git a/src/eigenvalues/types.rs b/src/eigenvalues/types.rs
index d26d828..b95c166 100644
--- a/src/eigenvalues/types.rs
+++ b/src/eigenvalues/types.rs
@@ -1,4 +1,5 @@
use ndarray::prelude::*;
+use ndarray::{Ix1, Ix2};
/// Errors from an eigenvalue problem.
#[derive(Debug)]
@@ -15,7 +16,7 @@ pub enum EigenError {
/// eigenvectors of the solution. For symmetric problems, the
/// eigenvectors are placed in `right_eigenvectors`.
pub struct Solution<IV, EV> {
- pub values: Array<EV, Ix>,
- pub left_vectors: Option<Array<IV, (Ix, Ix)>>,
- pub right_vectors: Option<Array<IV, (Ix, Ix)>>,
+ pub values: Array<EV, Ix1>,
+ pub left_vectors: Option<Array<IV, Ix2>>,
+ pub right_vectors: Option<Array<IV, Ix2>>,
}
diff --git a/src/impl_prelude.rs b/src/impl_prelude.rs
index 8d34883..b0d62e4 100644
--- a/src/impl_prelude.rs
+++ b/src/impl_prelude.rs
@@ -1,5 +1,5 @@
pub use ndarray::prelude::*;
-pub use ndarray::{Data, DataMut, DataOwned, Ix2};
+pub use ndarray::{Data, DataMut, DataOwned, Ix1, Ix2};
pub use util::internal::*;
pub use lapack::{c32, c64};
pub use types::Symmetric;
diff --git a/src/least_squares.rs b/src/least_squares.rs
index 7e5c995..e8d4928 100644
--- a/src/least_squares.rs
+++ b/src/least_squares.rs
@@ -142,12 +142,12 @@ pub trait LeastSquares: Sized + Clone {
///
/// This method will never return `LeastSquaresError::Degenerate`.
fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>,
- b: &ArrayBase<D2, Ix>)
- -> Result<LeastSquaresSolution<Self, Ix>, LeastSquaresError>
+ b: &ArrayBase<D2, Ix1>)
+ -> Result<LeastSquaresSolution<Self, Ix1>, LeastSquaresError>
where D1: Data<Elem = Self>,
D2: Data<Elem = Self>
{
- let n = b.dim();
+ let n = b.len();
// Create a new matrix, where the column vector is a degenerate 2-D matrix.
let b_mat = match b.to_owned().into_shape((n, 1)) {
@@ -174,14 +174,14 @@ fn resize_solution<T: Clone + Default, D>(mut b_sol: ArrayBase<D, Ix2>, n: usize
{
let b_dim = b_sol.dim();
- if b_dim.0 > n {
+ if b_dim[0] > n {
// If the matrix is overdetermined, we just need to truncate the
// solution.
b_sol.slice_mut(s![0..n as isize, ..]).to_owned()
} else {
// Otherwise, it's underdetermined, and we need to extend the solution
- let mut extended_sol = Array::default((n, b_dim.1));
- extended_sol.slice_mut(s![0..b_dim.0 as isize, ..]).assign(&b_sol);
+ let mut extended_sol = Array::default((n, b_dim[1]));
+ extended_sol.slice_mut(s![0..b_dim[0] as isize, ..]).assign(&b_sol);
extended_sol
}
}
@@ -197,8 +197,8 @@ macro_rules! impl_least_squares {
where D1: DataMut<Elem=Self> + DataOwned<Elem = Self>,
D2: DataMut<Elem=Self> + DataOwned<Elem = Self> {
- let a_dim = a.dim();
- let b_dim = b.dim();
+ let a_dim = a.dim_pattern();
+ let b_dim = b.dim_pattern();
// confirm same number of rows.
if a_dim.0 != b_dim.0 {
@@ -242,8 +242,8 @@ macro_rules! impl_least_squares {
where D1: DataMut<Elem=Self> + DataOwned<Elem = Self>,
D2: DataMut<Elem=Self> + DataOwned<Elem = Self> {
- let a_dim = a.dim();
- let b_dim = b.dim();
+ let a_dim = a.dim_pattern();
+ let b_dim = b.dim_pattern();
// confirm same number of rows.
if a_dim.0 != b_dim.0 {
@@ -256,7 +256,7 @@ macro_rules! impl_least_squares {
None => return Err(LeastSquaresError::BadLayout)
};
- let mut svs: Array<$sv_type, Ix> = Array::default(cmp::min(a_dim.0, a_dim.1));
+ let mut svs: Array<$sv_type, Ix1> = Array::default(cmp::min(a_dim.0, a_dim.1));
let mut rank: i32 = 0;
// compute result
diff --git a/src/solve_linear/general.rs b/src/solve_linear/general.rs
index d301653..506e90a 100644
--- a/src/solve_linear/general.rs
+++ b/src/solve_linear/general.rs
@@ -16,12 +16,12 @@ pub trait SolveLinear: Sized + Clone {
/// Solve the linear system A * x = b for square matrix `a` and column vector `b`.
fn compute_into<D1, D2>(a: ArrayBase<D1, Ix2>,
- b: ArrayBase<D2, Ix>)
- -> Result<ArrayBase<D2, Ix>, SolveError>
+ b: ArrayBase<D2, Ix1>)
+ -> Result<ArrayBase<D2, Ix1>, SolveError>
where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>,
D2: DataMut<Elem = Self> + DataOwned<Elem = Self>
{
- let n = b.dim();
+ let n = b.len();
// Create a new matrix, where the column vector is a degenerate 2-D matrix.
let b_mat = match b.into_shape((n, 1)) {
@@ -51,8 +51,8 @@ pub trait SolveLinear: Sized + Clone {
/// Solve the linear system A * x = b for square matrix `a` and column vector `b`.
fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>,
- b: &ArrayBase<D2, Ix>)
- -> Result<Array<Self, Ix>, SolveError>
+ b: &ArrayBase<D2, Ix1>)
+ -> Result<Array<Self, Ix1>, SolveError>
where D1: Data<Elem = Self>,
D2: Data<Elem = Self>
{
@@ -72,8 +72,8 @@ macro_rules! impl_solve_linear {
D2: DataMut<Elem=Self> + DataOwned<Elem = Self> {
// Make sure the input is square.
- let dim = a.dim();
- let b_dim = b.dim();
+ let dim = a.dim_pattern();
+ let b_dim = b.dim_pattern();
if dim.0 != dim.1 {
return Err(SolveError::NotSquare(dim.0, dim.1));
@@ -93,7 +93,7 @@ macro_rules! impl_solve_linear {
None => return Err(SolveError::InconsistentLayout)
};
- let mut perm: Array<i32, Ix> = Array::default(dim.0);
+ let mut perm: Array<i32, Ix1> = Array::default(dim.0);
$driver(layout, dim.0 as i32, b_dim.1 as i32,
slice, lda as i32,
diff --git a/src/solve_linear/symmetric.rs b/src/solve_linear/symmetric.rs
index 8370af8..43dc34e 100644
--- a/src/solve_linear/symmetric.rs
+++ b/src/solve_linear/symmetric.rs
@@ -19,12 +19,12 @@ pub trait SymmetricSolveLinear: Sized + Clone {
/// square matrix `a` and column vector `b`.
fn compute_into<D1, D2>(a: ArrayBase<D1, Ix2>,
uplo: Symmetric,
- b: ArrayBase<D2, Ix>)
- -> Result<ArrayBase<D2, Ix>, SolveError>
+ b: ArrayBase<D2, Ix1>)
+ -> Result<ArrayBase<D2, Ix1>, SolveError>
where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>,
D2: DataMut<Elem = Self> + DataOwned<Elem = Self>
{
- let n = b.dim();
+ let n = b.len();
// Create a new matrix, where the column vector is a degenerate 2-D matrix.
let b_mat = match b.into_shape((n, 1)) {
@@ -58,8 +58,8 @@ pub trait SymmetricSolveLinear: Sized + Clone {
/// square matrix `a` and column vector `b`.
fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>,
uplo: Symmetric,
- b: &ArrayBase<D2, Ix>)
- -> Result<Array<Self, Ix>, SolveError>
+ b: &ArrayBase<D2, Ix1>)
+ -> Result<Array<Self, Ix1>, SolveError>
where D1: Data<Elem = Self>,
D2: Data<Elem = Self>
{
@@ -81,8 +81,8 @@ macro_rules! impl_solve_linear {
D2: DataMut<Elem=Self> + DataOwned<Elem = Self> {
// Make sure the input is square.
- let dim = a.dim();
- let b_dim = b.dim();
+ let dim = a.dim_pattern();
+ let b_dim = b.dim_pattern();
if dim.0 != dim.1 {
return Err(SolveError::NotSquare(dim.0, dim.1));
@@ -102,7 +102,7 @@ macro_rules! impl_solve_linear {
None => return Err(SolveError::InconsistentLayout)
};
- let mut perm: Array<i32, Ix> = Array::default(dim.0);
+ let mut perm: Array<i32, Ix1> = Array::default(dim.0);
$driver(layout, uplo as u8, dim.0 as i32, b_dim.1 as i32,
slice, lda as i32,
diff --git a/src/svd/general.rs b/src/svd/general.rs
index e222bc4..8a786cb 100644
--- a/src/svd/general.rs
+++ b/src/svd/general.rs
@@ -51,7 +51,7 @@ enum SVDMethod {
/// Choose a method based on the problem.
fn select_svd_method(d: &Ix2, compute_either: bool) -> SVDMethod {
- let mx = cmp::max(d.0, d.1);
+ let mx = cmp::max(d[0], d[1]);
// When we're computing one of the singular vector sets, we have
// to compute both with divide and conquer. So, we're bound by the
@@ -79,7 +79,7 @@ macro_rules! impl_svd {
where D: DataMut<Elem=Self> + DataOwned<Elem = Self>{
let dim = mat.dim();
- let (m, n) = dim;
+ let (m, n) = dim.into_pattern();
let mut s = Array::default(cmp::min(m, n));
let (slice, layout, lda) = match slice_and_layout_mut(&mut mat) {
diff --git a/src/svd/types.rs b/src/svd/types.rs
index aa236d7..7c65118 100644
--- a/src/svd/types.rs
+++ b/src/svd/types.rs
@@ -1,6 +1,7 @@
use ndarray::prelude::*;
use num_traits::{Float, ToPrimitive};
use std::fmt::Display;
+use ndarray::{Ix1, Ix2};
/// Trait for singular values
///
@@ -21,15 +22,15 @@ pub struct SVDSolution<IV: Sized, SV: Sized> {
///
/// Singular values, which are guaranteed to be non-negative
/// reals, are returned in descending order.
- pub values: Array<SV, Ix>,
+ pub values: Array<SV, Ix1>,
/// The matrix `U` of left singular vectors.
- pub left_vectors: Option<Array<IV, (Ix, Ix)>>,
+ pub left_vectors: Option<Array<IV, Ix2>>,
/// The matrix V^t of singular vectors.
///
/// The transpose of V is stored, not V itself.
- pub right_vectors: Option<Array<IV, (Ix, Ix)>>,
+ pub right_vectors: Option<Array<IV, Ix2>>,
}
/// An error resulting from a `SVD::compute*` method.
diff --git a/src/util/internal.rs b/src/util/internal.rs
index dded08f..619ef5e 100644
--- a/src/util/internal.rs
+++ b/src/util/internal.rs
@@ -1,12 +1,16 @@
use ndarray::prelude::*;
use ndarray::{Ix2, DataMut};
+use ndarray::IntoDimension;
use lapack::c::Layout;
use std::slice;
/// Return an array with the specified dimensions and layout.
///
/// This function is used internally to ensure that matrices handed to
/// LAPACK are allocated with the requested memory layout.
-pub fn matrix_with_layout<T: Default>(d: Ix2, layout: Layout) -> Array<T, Ix2> {
+pub fn matrix_with_layout<P, T: Default>(d: P, layout: Layout) -> Array<T, Ix2>
+ where P: IntoDimension<Dim=Ix2>,
+{
+ let d = d.into_dimension();
Array::default(match layout {
Layout::RowMajor => d.into(),
Layout::ColumnMajor => d.f(),
@@ -29,7 +33,7 @@ pub fn slice_and_layout_mut<D, S: DataMut<Elem = D>>(mat: &mut ArrayBase<S, Ix2>
return None;
}
- let dim = mat.dim();
+ let dim = mat.dim_pattern();
// One of the strides must be 1
if strides.1 == 1 {
@@ -63,7 +67,7 @@ pub fn slice_and_layout_mut<D, S: DataMut<Elem = D>>(mat: &mut ArrayBase<S, Ix2>
pub fn slice_and_layout_matching_mut<D, S: DataMut<Elem = D>>(mat: &mut ArrayBase<S, Ix2>,
layout: Layout)
-> Option<(&mut [S::Elem], Ixs)> {
- let dim = mat.dim();
+ let dim = mat.dim_pattern();
// For column vectors, we can choose whatever layout we want.
if dim.1 == 1 {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment