Last active
August 15, 2017 06:03
-
-
Save basp1/cd9c407f167496175cae641d776bb047 to your computer and use it in GitHub Desktop.
Simple sparse matrix class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Diagnostics; | |
namespace Matlib | |
{ | |
/// <summary>
/// Simple row-oriented sparse matrix. Every row is a singly linked list of
/// cells; the cells of all rows live in three parallel node arrays
/// (<c>data</c>, <c>columns</c>, <c>next</c>). Nodes released by
/// <see cref="RemoveRow"/> are chained into a free list and reused by later
/// insertions, so the node arrays never shrink except in <see cref="Clear"/>.
/// </summary>
/// <typeparam name="T">Type of the stored cell values.</typeparam>
public class SparseMatrix<T>
{
    public enum SparseMatrixType
    {
        UNSYMMENTRIC, UPPER_SYMMENTRIC, LOWER_SYMMENTRIC
    }

    // Sentinel node index meaning "no node" (end of a list / empty row / empty free list).
    private const int NIL = -1;

    // Never changed by this class after construction; only carried over by the copy constructor.
    SparseMatrixType type;

    // Per-row bookkeeping, each list has rowCount entries.
    List<int> rows;    // index of the first node of the row, or NIL
    List<int> last;    // index of the last node of the row, or NIL
    List<int> count;   // number of cells currently stored in the row

    // Parallel node arrays, indexed by node index.
    List<T> data;      // cell value
    List<int> columns; // cell column
    List<int> next;    // index of the following node in the same list, or NIL

    int ie;            // head of the free list of released nodes, or NIL
    int nnz;           // number of cells currently stored in the whole matrix
    int rowCount = 0;
    int columnCount = 0;

    /// <summary>Creates a deep copy of the bookkeeping and node arrays of <paramref name="other"/>.</summary>
    /// <param name="other">Matrix to copy.</param>
    /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
    public SparseMatrix(SparseMatrix<T> other)
    {
        if (ReferenceEquals(other, null))
        {
            throw new ArgumentNullException("other");
        }

        this.type = other.type;
        this.rows = new List<int>(other.rows);
        this.data = new List<T>(other.data);
        this.columns = new List<int>(other.columns);
        this.next = new List<int>(other.next);
        this.last = new List<int>(other.last);
        this.count = new List<int>(other.count);
        this.ie = other.ie;
        this.nnz = other.nnz;
        this.rowCount = other.rowCount;
        this.columnCount = other.columnCount;
    }

    /// <summary>Creates an empty matrix.</summary>
    /// <param name="rowCount">Number of rows; must be positive.</param>
    /// <param name="columnCount">
    /// Number of columns; must be positive. The default <c>int.MaxValue</c> leaves the
    /// column count effectively unbounded (and makes <see cref="Transpose"/> unavailable).
    /// </param>
    public SparseMatrix(int rowCount, int columnCount = int.MaxValue)
    {
        Debug.Assert(rowCount > 0);
        Debug.Assert(columnCount > 0);

        this.rows = new List<int>(rowCount);
        this.last = new List<int>(rowCount);
        this.count = new List<int>(rowCount);
        for (int i = 0; i < rowCount; i++)
        {
            last.Add(NIL);
            rows.Add(NIL);
            count.Add(0);
        }

        // rowCount is only an initial capacity guess; the node arrays grow on demand.
        this.data = new List<T>(rowCount);
        this.columns = new List<int>(rowCount);
        this.next = new List<int>(rowCount);

        this.ie = NIL;
        this.nnz = 0;
        this.rowCount = rowCount;
        this.columnCount = columnCount;
    }

    /// <summary>Builds a <paramref name="size"/> x <paramref name="size"/> identity matrix of doubles.</summary>
    public static SparseMatrix<double> Eye(int size)
    {
        var spm = new SparseMatrix<double>(size, size);
        for (int i = 0; i < size; i++)
        {
            spm.PushBack(i, i, 1.0);
        }
        return spm;
    }

    public SparseMatrixType Type { get { return type; } }
    public int RowCount { get { return rowCount; } }
    public int ColumnCount { get { return columnCount; } }

    /// <summary>Number of cells currently stored.</summary>
    public int Nnz { get { return nnz; } }

    /// <summary>
    /// Structural equality: same dimensions, same number of cells per row, and the same
    /// (column, value) sequence in every row. The order of cells inside a row matters.
    /// </summary>
    /// <returns>
    /// <c>false</c> when <paramref name="obj"/> is null or is not a <see cref="SparseMatrix{T}"/>.
    /// </returns>
    public override bool Equals(object obj)
    {
        var other = obj as SparseMatrix<T>;
        if (ReferenceEquals(other, null))
        {
            // Equals must not throw for null or foreign types.
            return false;
        }
        if (ReferenceEquals(this, other))
        {
            return true;
        }
        if (Nnz != other.Nnz || RowCount != other.RowCount || ColumnCount != other.ColumnCount)
        {
            return false;
        }
        if (!count.SequenceEqual(other.count))
        {
            return false;
        }

        // The default comparer handles null cell values, unlike calling Equals on the value itself.
        var comparer = EqualityComparer<T>.Default;
        for (int i = 0; i < RowCount; i++)
        {
            // Both lists have the same length here because the per-row counts matched above.
            for (int j = rows[i], other_j = other.rows[i];
                 NIL != j && NIL != other_j;
                 j = next[j], other_j = other.next[other_j])
            {
                if (columns[j] != other.columns[other_j])
                {
                    return false;
                }
                if (!comparer.Equals(data[j], other.data[other_j]))
                {
                    return false;
                }
            }
        }
        return true;
    }

    /// <summary>
    /// Hash code consistent with <see cref="Equals(object)"/>: it is derived only from the
    /// dimensions and the cell count, all of which Equals requires to match. The matrix is
    /// mutable, so the value changes when cells are added or removed.
    /// </summary>
    public override int GetHashCode()
    {
        unchecked
        {
            int hash = 17;
            hash = hash * 31 + rowCount;
            hash = hash * 31 + columnCount;
            hash = hash * 31 + nnz;
            return hash;
        }
    }

    /// <summary>Number of cells stored in <paramref name="row"/>.</summary>
    public int Count(int row)
    {
        Debug.Assert(row >= 0 && row < rowCount);
        return count[row];
    }

    /// <summary>Removes every cell and drops the node arrays (including the free list).</summary>
    public void Clear()
    {
        this.ie = NIL;
        this.nnz = 0;
        for (int i = 0; i < rowCount; i++)
        {
            this.rows[i] = NIL;
            this.count[i] = 0;
            this.last[i] = NIL;
        }
        this.data.Clear();
        this.columns.Clear();
        this.next.Clear();
    }

    /// <summary>First cell of <paramref name="row"/>; the row must hold at least one cell.</summary>
    public Cell<T> First(int row)
    {
        Debug.Assert(Count(row) >= 1);
        int head = rows[row];
        return new Cell<T> { Column = columns[head], Value = data[head] };
    }

    /// <summary>Second cell of <paramref name="row"/>; the row must hold at least two cells.</summary>
    public Cell<T> Second(int row)
    {
        Debug.Assert(Count(row) >= 2);
        int second = next[rows[row]];
        return new Cell<T> { Column = columns[second], Value = data[second] };
    }

    /// <summary>
    /// Replaces the contents of <paramref name="list"/> with the cells of <paramref name="row"/>,
    /// in list order.
    /// </summary>
    public void CopyRow(int row, ref List<Cell<T>> list)
    {
        Debug.Assert(row >= 0 && row < rowCount);
        list.Clear();

        // Only grow the buffer: shrinking it on every call would force reallocations
        // when the same list is reused for many rows (as SortRows does).
        int needed = Count(row);
        if (list.Capacity < needed)
        {
            list.Capacity = needed;
        }

        for (int j = rows[row]; NIL != j; j = next[j])
        {
            list.Add(new Cell<T>() { Column = columns[j], Value = data[j] });
        }
    }

    /// <summary>Returns a new list holding the cells of <paramref name="row"/>, in list order.</summary>
    public List<Cell<T>> CopyRow(int row)
    {
        var list = new List<Cell<T>>(Count(row));
        CopyRow(row, ref list);
        return list;
    }

    /// <summary>True when at least one cell of <paramref name="row"/> satisfies <paramref name="predicate"/>.</summary>
    public bool Contains(int row, Predicate<Cell<T>> predicate)
    {
        Debug.Assert(row >= 0 && row < rowCount);
        for (int j = rows[row]; NIL != j; j = next[j])
        {
            if (predicate(new Cell<T>() { Column = columns[j], Value = data[j] }))
            {
                return true;
            }
        }
        return false;
    }

    /// <summary>Inserts <paramref name="cell"/> at the front of <paramref name="row"/>.</summary>
    public void Push(int row, Cell<T> cell)
    {
        Push(row, cell.Column, cell.Value);
    }

    /// <summary>
    /// Inserts a cell at the front of <paramref name="row"/>. No check is made for an
    /// existing cell in the same column.
    /// </summary>
    public void Push(int row, int column, T value)
    {
        Debug.Assert(row >= 0 && row < rowCount);
        Debug.Assert(column < columnCount);

        int node = AllocateNode(column, value);
        if (NIL == rows[row])
        {
            // First cell of the row is also its last one.
            last[row] = node;
        }
        else
        {
            next[node] = rows[row];
        }
        rows[row] = node;

        count[row] += 1;
        nnz += 1;
    }

    /// <summary>Appends <paramref name="cell"/> at the end of <paramref name="row"/>.</summary>
    public void PushBack(int row, Cell<T> cell)
    {
        PushBack(row, cell.Column, cell.Value);
    }

    /// <summary>
    /// Appends a cell at the end of <paramref name="row"/>. No check is made for an
    /// existing cell in the same column.
    /// </summary>
    public void PushBack(int row, int column, T value)
    {
        Debug.Assert(row >= 0 && row < rowCount);
        Debug.Assert(column < columnCount);

        int node = AllocateNode(column, value);
        if (NIL == rows[row])
        {
            rows[row] = node;
        }
        else
        {
            next[last[row]] = node;
        }
        last[row] = node;

        count[row] += 1;
        nnz += 1;
    }

    /// <summary>
    /// Removes all cells of <paramref name="row"/>. Their values are reset to
    /// <c>default(T)</c> and the nodes are put on the free list for reuse.
    /// </summary>
    public void RemoveRow(int row)
    {
        Debug.Assert(row >= 0 && row < rowCount);
        if (0 == Count(row))
        {
            return;
        }

        // Drop the stored values so reference-type cells do not stay reachable.
        for (int j = rows[row]; NIL != j; j = next[j])
        {
            data[j] = default(T);
        }

        // Splice the whole row list in front of the free list.
        next[last[row]] = ie;
        ie = rows[row];

        nnz -= count[row];
        rows[row] = NIL;
        last[row] = NIL;
        count[row] = 0;
    }

    /// <summary>
    /// Moves every cell of <paramref name="srcRow"/> to the end of <paramref name="dstRow"/>,
    /// leaving <paramref name="srcRow"/> empty. Merging a row into itself is a no-op.
    /// </summary>
    public void MergeRows(int dstRow, int srcRow)
    {
        Debug.Assert(dstRow >= 0 && dstRow < rowCount);
        Debug.Assert(srcRow >= 0 && srcRow < rowCount);

        // Self-merge would link the row's tail to its own head and then empty the row,
        // leaking its nodes and leaving Nnz out of sync with the stored cells.
        if (dstRow == srcRow || 0 == Count(srcRow))
        {
            return;
        }

        if (0 == Count(dstRow))
        {
            rows[dstRow] = rows[srcRow];
        }
        else
        {
            next[last[dstRow]] = rows[srcRow];
        }
        last[dstRow] = last[srcRow];
        count[dstRow] += count[srcRow];

        rows[srcRow] = NIL;
        last[srcRow] = NIL;
        count[srcRow] = 0;
    }

    /// <summary>
    /// Returns a new matrix with rows and columns swapped. Requires an explicit column
    /// count (not the <c>int.MaxValue</c> default), since it becomes the row count of the result.
    /// </summary>
    public SparseMatrix<T> Transpose()
    {
        Debug.Assert(int.MaxValue != columnCount);
        var t = new SparseMatrix<T>(columnCount, rowCount);
        for (int i = 0; i < rowCount; i++)
        {
            for (int j = rows[i]; NIL != j; j = next[j])
            {
                t.PushBack(columns[j], i, data[j]);
            }
        }
        return t;
    }

    /// <summary>
    /// Reorders the cells of every row by ascending column. <c>List.Sort</c> is not stable,
    /// so the relative order of cells sharing a column is unspecified.
    /// </summary>
    public void SortRows()
    {
        var list = new List<Cell<T>>();
        for (int i = 0; i < rowCount; i++)
        {
            if (NIL == rows[i])
            {
                continue;
            }

            CopyRow(i, ref list);
            list.Sort((x, y) => x.Column.CompareTo(y.Column));

            // The released nodes go to the free list and are reused by the PushBack calls below.
            RemoveRow(i);
            for (int j = 0; j < list.Count; j++)
            {
                PushBack(i, list[j]);
            }
        }
    }

    /// <summary>
    /// Takes a node from the free list, or appends a new one to the node arrays when the
    /// free list is empty, and stores the cell in it. The returned node is not linked
    /// into any row yet and its <c>next</c> entry is NIL.
    /// </summary>
    private int AllocateNode(int column, T value)
    {
        int node;
        if (ie >= 0)
        {
            node = ie;
            ie = next[node];
            data[node] = value;
            columns[node] = column;
            next[node] = NIL;
        }
        else
        {
            // The node arrays are parallel, so the new index is their common length.
            node = data.Count;
            data.Add(value);
            columns.Add(column);
            next.Add(NIL);
        }
        return node;
    }

    /// <summary>A stored cell: its column index and value.</summary>
    // NOTE(review): the type parameter shadows the enclosing class's T (compiler warning
    // CS0693); left as-is because changing it would alter the public type name.
    public struct Cell<T>
    {
        public int Column;
        public T Value;
    }
}
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using Matlib; | |
using Microsoft.VisualStudio.TestTools.UnitTesting; | |
namespace Matlib.Tests | |
{ | |
/// <summary>Unit tests for <see cref="SparseMatrix{T}"/>.</summary>
[TestClass]
public class SparseMatrixTests
{
    /// <summary>Push inserts at the front of the row, so cells come back in reverse insertion order.</summary>
    [TestMethod]
    public void SparseMatrixPush()
    {
        var spm = new SparseMatrix<char>(3);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);

        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'c');
        spm.Push(1, 0, 'd');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);

        var list = spm.CopyRow(0);
        Assert.AreEqual(0, list.Count);

        list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('d', list[0].Value);
        Assert.AreEqual('c', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    /// <summary>PushBack appends at the end of the row, so cells come back in insertion order.</summary>
    // The [TestMethod] attribute was missing, so this test was never executed.
    [TestMethod]
    public void SparseMatrixPushBack()
    {
        var spm = new SparseMatrix<char>(3);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);

        spm.PushBack(1, 0, 'a');
        spm.PushBack(1, 0, 'b');
        spm.PushBack(1, 0, 'c');
        spm.PushBack(1, 0, 'd');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);

        var list = spm.CopyRow(0);
        Assert.AreEqual(0, list.Count);

        list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('c', list[2].Value);
        Assert.AreEqual('d', list[3].Value);
    }

    /// <summary>Contains finds values present in the row and rejects absent ones.</summary>
    [TestMethod]
    public void SparseMatrixContains()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');

        Assert.IsTrue(spm.Contains(1, cell => 'a' == cell.Value));
        Assert.IsTrue(spm.Contains(1, cell => 'b' == cell.Value));
        Assert.IsFalse(spm.Contains(1, cell => 'c' == cell.Value));
    }

    /// <summary>RemoveRow empties only the given row, and the row can be refilled afterwards.</summary>
    [TestMethod]
    public void SparseMatrixRemoveRow()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');

        // Removing an empty row changes nothing.
        spm.RemoveRow(0);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);

        spm.RemoveRow(1);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);

        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');

        var list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    /// <summary>Clear leaves every row empty and Nnz at zero.</summary>
    [TestMethod]
    public void SparseMatrixClear()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');

        spm.Clear();
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);
    }

    /// <summary>MergeRows moves all cells of the source row (filled with Push) into the destination row.</summary>
    [TestMethod]
    public void SparseMatrixMergeRows()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);

        spm.MergeRows(0, 1);
        Assert.AreEqual(4, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);

        var list = spm.CopyRow(0);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    /// <summary>Same as SparseMatrixMergeRows, but the source row is filled with PushBack.</summary>
    // The [TestMethod] attribute was missing, so this test was never executed.
    [TestMethod]
    public void SparseMatrixMergeRows2()
    {
        var spm = new SparseMatrix<char>(3);
        spm.PushBack(1, 0, 'a');
        spm.PushBack(1, 0, 'b');
        spm.PushBack(1, 0, 'b');
        spm.PushBack(1, 0, 'a');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);

        spm.MergeRows(0, 1);
        Assert.AreEqual(4, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);

        var list = spm.CopyRow(0);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    /// <summary>SortRows reorders the cells of a row by ascending column.</summary>
    [TestMethod]
    public void SparseMatrixSortRows()
    {
        var spm = new SparseMatrix<int>(3);
        spm.Push(1, 3, 3);
        spm.Push(1, 1, 1);
        spm.Push(1, 0, 0);
        spm.Push(1, 2, 2);

        // Push prepends, so the row reads back in reverse insertion order.
        var list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual(2, list[0].Column);
        Assert.AreEqual(0, list[1].Column);
        Assert.AreEqual(1, list[2].Column);
        Assert.AreEqual(3, list[3].Column);

        spm.SortRows();

        list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual(0, list[0].Column);
        Assert.AreEqual(1, list[1].Column);
        Assert.AreEqual(2, list[2].Column);
        Assert.AreEqual(3, list[3].Column);
    }

    /// <summary>Transposing twice yields a matrix equal to the original.</summary>
    [TestMethod]
    public void SparseMatrixTranspose()
    {
        var m = new SparseMatrix<int>(2, 3);
        m.PushBack(0, 0, 0); m.PushBack(0, 1, 1); m.PushBack(0, 2, 2);
        m.PushBack(1, 0, 3); m.PushBack(1, 1, 4); m.PushBack(1, 2, 5);

        var t = m.Transpose();
        var tt = t.Transpose();
        var ttt = tt.Transpose();

        Assert.IsFalse(m.Equals(t));
        Assert.IsTrue(m.Equals(tt));
        Assert.IsTrue(t.Equals(ttt));
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment