Last active
May 10, 2017 18:53
-
-
Save tyler274/3c0a4e1060aa50025966975cf7952777 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module SparseMatrix where | |
import qualified Data.Map as M | |
import qualified Data.Set as S | |
import Data.List (transpose, findIndex, elemIndices) | |
import qualified Data.Vector.Unboxed as U | |
import Debug.Trace | |
import Data.Maybe | |
-- | A sparse matrix that stores only its nonzero entries, together with
-- the sets of row and column indices in which those entries occur.
data SparseMatrix a = SM
  { bounds     :: (Integer, Integer)         -- ^ (number of rows, number of columns)
  , rowIndices :: S.Set Integer              -- ^ row indices that hold a nonzero value
  , colIndices :: S.Set Integer              -- ^ column indices that hold a nonzero value
  , vals       :: M.Map (Integer, Integer) a -- ^ stored values keyed by (row, column)
  }
  deriving (Eq, Show)
-- | A sparse vector: its logical length plus the entries it stores,
-- keyed by index.
data SparseVector a = SV
  { size :: Integer          -- ^ logical length of the vector
  , vec  :: M.Map Integer a  -- ^ stored entries keyed by index
  }
  deriving (Eq)
-- | Build a sparse matrix from (index, value) pairs and its bounds.
--
-- @sparseMatrix pairs (rows, cols)@ keeps every nonzero value of @pairs@.
-- Indices are 1-based: a pair is accepted only when @1 <= row <= rows@ and
-- @1 <= col <= cols@ (the same range 'getSM' accepts).
--
-- When an index occurs more than once the last pair wins (as with
-- 'M.fromList'). Zeros are removed afterwards, so a trailing zero erases an
-- earlier value at the same index.
--
-- Calls 'error' when either bound is below 1 or any pair is out of bounds.
sparseMatrix :: (Eq a, Num a) =>
  [((Integer, Integer), a)] -> (Integer, Integer) -> SparseMatrix a
sparseMatrix pairs bounds@(rows, cols)
  | rows < 1 = error "row bounds less than 1"
  | cols < 1 = error "column bounds less than 1"
  | not (all (inBounds . fst) pairs) = error "index pair not in bounds"
  | otherwise =
      SM { bounds = bounds
         , rowIndices = S.fromList (map fst keys)
         , colIndices = S.fromList (map snd keys)
         , vals = nonzero
         }
  where
    -- Lower limits are checked too; previously (0, 0) or negative indices slipped through.
    inBounds (r, c) = r >= 1 && r <= rows && c >= 1 && c <= cols
    -- Build the map first so duplicates resolve by "last wins", then drop zeros.
    nonzero = M.filter (/= 0) (M.fromList pairs)
    keys = M.keys nonzero
-- | Entry-wise sum of two matrices with identical bounds.
--
-- Positions whose sum is zero are dropped, and the row and column index
-- sets are rebuilt from the surviving entries. Calls 'error' when the
-- bounds differ.
addSM :: (Eq a, Num a) => SparseMatrix a -> SparseMatrix a -> SparseMatrix a
addSM sm1 sm2
  | bounds sm1 == bounds sm2 =
      SM { bounds = bounds sm1
         , rowIndices = S.fromList rs
         , colIndices = S.fromList cs
         , vals = summed
         }
  | otherwise = error "matrices are not compatible"
  where
    summed = M.filter (/= 0) (M.unionWith (+) (vals sm1) (vals sm2))
    (rs, cs) = unzip (M.keys summed)
-- | Multiply every stored value by -1.
--
-- Bounds and index sets are reused unchanged.
negateSM :: (Eq a, Num a) => SparseMatrix a -> SparseMatrix a
negateSM sm1 =
  SM { bounds = bounds sm1
     , rowIndices = rowIndices sm1
     , colIndices = colIndices sm1
     , vals = flipped
     }
  where
    flipped = fmap (\x -> x * (-1)) (vals sm1)
-- | Entry-wise difference: adds the negation of the second matrix to the first.
subSM :: (Eq a, Num a) => SparseMatrix a -> SparseMatrix a -> SparseMatrix a
subSM sm1 sm2 = sm1 `addSM` negateSM sm2
-- | Read the entry at a 1-based (row, column) index.
--
-- Returns 0 for positions with no stored value. Calls 'error' with
-- "out of bounds" when the index lies outside the matrix bounds.
getSM :: Num a => SparseMatrix a -> (Integer, Integer) -> a
getSM (SM (numRows, numCols) _ _ entries) (r, c)
  | r > numRows || r < 1 || c > numCols || c < 1 = error "out of bounds"
  -- Look the cell up directly: a row and a column can each hold nonzeros
  -- while their intersection is empty, which made the old 'fromJust' crash.
  | otherwise = M.findWithDefault 0 (r, c) entries
-- | Number of rows (first component of the bounds).
rowsSM :: SparseMatrix a -> Integer
rowsSM = fst . bounds
-- | Number of columns (second component of the bounds).
colsSM :: SparseMatrix a -> Integer
colsSM = snd . bounds
-- | Operator form of 'addSM'.
(<|+|>) :: (Eq a, Num a) => SparseMatrix a -> SparseMatrix a -> SparseMatrix a
(<|+|>) = addSM
-- | Operator form of 'subSM'.
(<|-|>) :: (Eq a, Num a) => SparseMatrix a -> SparseMatrix a -> SparseMatrix a
(<|-|>) = subSM
-- | Operator form of 'mulSM'.
(<|*|>) :: (Eq a, Num a, Show a) => SparseMatrix a -> SparseMatrix a -> SparseMatrix a
(<|*|>) = mulSM
-- | Operator form of 'getSM': @m <!> (row, col)@ reads one entry.
(<!>) :: (Num a) => SparseMatrix a -> (Integer, Integer) -> a
(<!>) = getSM
-- | A 0x0 matrix with no stored values and empty index sets.
emptySparseMatrix :: SparseMatrix a
emptySparseMatrix =
  SM { bounds = (0, 0)
     , rowIndices = S.empty
     , colIndices = S.empty
     , vals = M.empty
     }
-- | True when the vector stores no entries at all.
isZeroVec :: SparseVector a -> Bool
isZeroVec v = M.null (vec v)
-- | Extract row @i@ as a sparse vector whose size is the column count.
--
-- Entries are keyed by column index; zero values are dropped. Previously
-- the row number was ignored and every row was merged into the result.
row :: (Num a, Eq a) => SparseMatrix a -> Integer -> SparseVector a
row m i = SV (colsSM m) (M.mapKeysMonotonic snd inRow)
  where
    -- Keys that share row @i@ are ordered by column, so projecting with 'snd'
    -- keeps them strictly increasing, as 'M.mapKeysMonotonic' requires.
    inRow = M.filterWithKey (\(r, _) v -> r == i && v /= 0) (vals m)
-- | Remove every stored entry of row @i@ (the matrix bounds are unchanged),
-- then rebuild the row and column index sets from what remains.
deleteRow :: (Num a, Eq a) => Integer -> SparseMatrix a -> SparseMatrix a
deleteRow i m = m { rowIndices = rowSet, colIndices = colSet, vals = kept }
  where
    kept = M.filterWithKey (\(r, _) v -> r /= i && v /= 0) (vals m)
    rowSet = S.fromList [ r | (r, _) <- M.keys kept ]
    colSet = S.fromList [ c | (_, c) <- M.keys kept ]
-- | Remove the entry stored at index @j@; the vector's size is unchanged.
deleteVectorElem :: (Num a) => SparseVector a -> Integer -> SparseVector a
deleteVectorElem v j = v { vec = remaining }
  where
    remaining = j `M.delete` vec v
-- | Replace row @i@ with the result of applying @f@ to it.
--
-- @f@ receives the row as a 'SparseVector' whose size is the column count,
-- with entries keyed by column. The entries @f@ returns are written back
-- into row @i@ (zeros dropped). All other rows are kept, and both index
-- sets are rebuilt.
--
-- NOTE(review): column keys returned by @f@ are not checked against the bounds.
updateRow :: (Num a, Eq a) => (SparseVector a -> SparseVector a) -> Integer -> SparseMatrix a -> SparseMatrix a
updateRow f i m =
  m { rowIndices = S.fromList (map fst (M.keys newVals))
    , colIndices = S.fromList (map snd (M.keys newVals))
    , vals = newVals
    }
  where
    (current, others) = M.partitionWithKey (\(r, _) _ -> r == i) (vals m)
    -- Within one row, keys are ordered by column, so both key maps below are
    -- strictly monotonic as 'M.mapKeysMonotonic' requires.
    oldRow = M.filter (/= 0) (M.mapKeysMonotonic snd current)
    newRow = vec (f (SV (colsSM m) oldRow))
    -- The rows are disjoint, so the union simply splices row i back in.
    newVals = M.union others (M.mapKeysMonotonic (\c -> (i, c)) (M.filter (/= 0) newRow))
-- | Delete the entry at (i, j). If row @i@ ends up with no entries, the
-- row is also cleared via 'deleteRow'.
deleteElem :: (Num a, Eq a) => SparseMatrix a -> (Integer, Integer) -> SparseMatrix a
deleteElem m (i, j)
  | isZeroVec (row withoutElem i) = deleteRow i withoutElem
  | otherwise = withoutElem
  where
    withoutElem = updateRow (`deleteVectorElem` j) i m
-- | Insert or replace the entry at a 1-based (row, column) index.
--
-- Inserting 0 removes the entry (via 'deleteElem'). Calls 'error' with the
-- same message 'sparseMatrix' uses when the index lies outside the bounds.
-- Previously such entries were stored silently, making the matrix
-- inconsistent with 'getSM'.
ins :: (Num a, Eq a) => SparseMatrix a -> ((Integer, Integer), a) -> SparseMatrix a
ins m ((i, j), x)
  | i < 1 || i > rows || j < 1 || j > cols = error "index pair not in bounds"
  | x == 0 = deleteElem m (i, j)
  | otherwise =
      m { rowIndices = S.fromList (map fst keys)
        , colIndices = S.fromList (map snd keys)
        , vals = newVals
        }
  where
    (rows, cols) = bounds m
    newVals = M.filter (/= 0) (M.insert (i, j) x (vals m))
    keys = M.keys newVals
-- | Transpose a matrix: entry (r, c) moves to (c, r) and the bounds swap.
-- The index sets are rebuilt from the flipped keys.
transpose :: (Num a, Eq a) => SparseMatrix a -> SparseMatrix a
transpose m@(SM (rows, cols) _ _ _) =
  m { bounds = (cols, rows)
    , rowIndices = S.fromList (map fst flippedKeys)
    , colIndices = S.fromList (map snd flippedKeys)
    , vals = flipped
    }
  where
    flipped = M.mapKeys (\(r, c) -> (c, r)) (vals m)
    flippedKeys = M.keys flipped
-- | Dot product of two index-keyed maps: multiply the values at indices
-- present in both maps and sum the products with a strict left fold.
--
-- A sum that compares equal to 0 is returned as the literal 0 (so, for
-- example, a Double @-0.0@ comes back as @0.0@).
dotMap :: (Num a, Eq a) => M.Map Integer a -> M.Map Integer a -> a
dotMap v w
  | total == 0 = 0
  | otherwise = total
  where
    total = M.foldl' (+) 0 (M.intersectionWith (*) v w)
-- | Dot product of two 'SparseVector's, computed over their stored entries.
-- The sizes are not compared.
dot :: (Eq a, Num a) => SparseVector a -> SparseVector a -> a
dot v = dotMap (vec v) . vec
-- | Matrix-vector product.
--
-- Entry @r@ of the result is the sum over @c@ of @m[r, c] * v[c]@. The
-- result's size is the matrix's row count, and only nonzero entries are
-- stored. Calls 'error' when the matrix column count differs from the
-- vector size.
--
-- Previously this returned an empty vector and only traced a debug value.
mulMV :: (Num a, Eq a, Show a) => SparseMatrix a -> SparseVector a -> SparseVector a
mulMV (SM (h, w) _ _ m) (SV n v)
  | w /= n = error "matrix and vector are not compatible"
  | otherwise = SV h (M.filter (/= 0) products)
  where
    -- Only matrix entries whose column has a stored vector entry contribute.
    products = M.fromListWith (+)
      [ (r, x * y) | ((r, c), x) <- M.toList m, Just y <- [M.lookup c v] ]
-- | Sparse matrix product.
--
-- For an (n x k) and a (k x p) matrix, entry (i, j) of the (n x p) result
-- is the sum over @l@ of @first[i, l] * second[l, j]@. Zero sums are
-- dropped and both index sets are rebuilt. Calls 'error' when the inner
-- dimensions differ.
--
-- Previously no multiplication happened: the result was just the first
-- matrix's entries filtered by the second matrix's column set.
mulSM :: (Num a, Eq a, Show a) => SparseMatrix a -> SparseMatrix a -> SparseMatrix a
mulSM firstMat secondMat
  | inner1 /= inner2 = error "matrices are not compatible"
  | otherwise =
      SM { bounds = (outerRows, outerCols)
         , rowIndices = S.fromList (map fst (M.keys productVals))
         , colIndices = S.fromList (map snd (M.keys productVals))
         , vals = productVals
         }
  where
    (outerRows, inner1) = bounds firstMat
    (inner2, outerCols) = bounds secondMat
    -- Group the second matrix by row, so each entry (i, l) of the first
    -- matrix meets only the entries of row l that it multiplies.
    secondByRow = M.fromListWith (++)
      [ (l, [(j, y)]) | ((l, j), y) <- M.toList (vals secondMat) ]
    productVals = M.filter (/= 0) $ M.fromListWith (+)
      [ ((i, j), x * y)
      | ((i, l), x) <- M.toList (vals firstMat)
      , (j, y) <- M.findWithDefault [] l secondByRow
      ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment