@jiahao
Last active December 4, 2019 05:52
Iterate over SparseMatrixCSC stored entries. Implements Julia's new iterator protocol (new as of v0.7) https://julialang.org/blog/2018/07/iterators-in-julia-0.7
# Iterate over SparseMatrixCSC stored entries, ignoring stored zeros and
# missing values.
#
# Implements Julia's new iterator protocol (new as of v0.7)
# Ref: https://julialang.org/blog/2018/07/iterators-in-julia-0.7
#
# Jiahao Chen 2019-12-03
#
# MIT License available upon request
#
# Implementation note:
#
# This version of the iterator skips stored zeros and missing values.
# The comments in the code tell you what to change if you don't want this behavior.
#
########################################################################
#
# The implementation
#
########################################################################
using SparseArrays
struct StoredEntries{T}
    A::T
end
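# A small illustration, assuming an arbitrary 3-by-3 example matrix: the
# methods below read the three compressed-sparse-column (CSC) arrays of the
# wrapped SparseMatrixCSC directly. For column j, the stored entries occupy
# positions colptr[j]:colptr[j+1]-1 of rowval (row indices) and nzval
# (values), and an empty column has colptr[j] == colptr[j+1].
using SparseArrays
let
    # Entries (1,1) = 10, (3,1) = 30, (2,3) = 20; column 2 stores nothing
    S = sparse([1, 3, 2], [1, 1, 3], [10.0, 30.0, 20.0], 3, 3)
    @assert S.colptr == [1, 3, 3, 4]
    @assert S.rowval == [1, 3, 2]
    @assert S.nzval == [10.0, 30.0, 20.0]
    # The number of stored entries is recoverable from colptr alone
    @assert nnz(S) == S.colptr[end] - 1
end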
# Count how many stored values are neither zero nor missing.
#
# If you don't need to check the values, you can simply use
# Base.length(A::StoredEntries{<:SparseMatrixCSC}) = nnz(A.A)
# (the <: is needed because parametric types are invariant, and nnz is
# used because nzval may be longer than the number of stored entries)
function Base.length(A::StoredEntries{SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti}
    n = 0
    for k in 1:nnz(A.A)
        v = A.A.nzval[k]
        if ismissing(v) || iszero(v)
            continue
        end
        n += 1
    end
    n
end
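# A quick check of the distinction drawn above, on an arbitrary example:
# sparse() keeps an explicit zero from its input as a stored entry, so nnz
# counts it, while a value-checking count skips it. Only the first nnz(S)
# elements of nonzeros(S) are treated as in use here.
using SparseArrays
let
    # The (2,2) entry is a stored (explicit) zero
    S = sparse([1, 2, 3], [1, 2, 3], [1.5, 0.0, -2.0])
    @assert nnz(S) == 3
    active = view(nonzeros(S), 1:nnz(S))
    @assert count(v -> !(ismissing(v) || iszero(v)), active) == 2
    # dropzeros returns a copy with the stored zeros removed
    @assert nnz(dropzeros(S)) == 2
end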
# Each iteration yields a tuple (i, j, v): the row index, the column
# index, and the matrix element.
Base.eltype(::StoredEntries{SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} =
    Tuple{Ti,Ti,Tv}
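# Why the index type matters for eltype, sketched under the assumption of a
# matrix built with 32-bit indices: rowval then has element type Ti = Int32,
# but a column counter started from the literal 1 is an Int, so a tuple
# built from it matches Tuple{Ti,Ti,Tv} only after the column index is
# converted to Ti.
using SparseArrays
let
    # Int32 index vectors give a SparseMatrixCSC{Float64,Int32}
    S = sparse(Int32[1, 2], Int32[1, 2], [3.0, 4.0])
    @assert eltype(rowvals(S)) == Int32
    j = 1  # integer literals are Int (Int64 on 64-bit platforms)
    @assert (rowvals(S)[1], Int32(j), nonzeros(S)[1]) isa Tuple{Int32,Int32,Float64}
end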
# The iteration state is a pair of integers:
# - the column index, j
# - the internal pointer, k: the position in nzval/rowval of the next
#   stored value to examine
function Base.iterate(A::StoredEntries{SparseMatrixCSC{Tv,Ti}}, state=(1,1)) where {Tv,Ti}
    j, k = state
    # Only the first nnz entries of nzval/rowval are in use
    nstored = nnz(A.A)
    # Skip stored values that are missing or zero, without running past
    # the last stored entry (delete this loop to visit every stored entry)
    while k <= nstored && (ismissing(A.A.nzval[k]) || iszero(A.A.nzval[k]))
        k += 1
    end
    # Terminate once every stored entry has been examined
    if k > nstored
        return nothing  # signals the end of iteration
    end
    v = A.A.nzval[k]
    # Advance the column index until column j contains position k
    while k >= A.A.colptr[j+1]
        j += 1
    end
    # Read row index from internal storage
    i = A.A.rowval[k]
    # Return the (row, column, value) triple, with the column index converted
    # to Ti to match eltype, and the state for the next iteration
    return (i, Ti(j), v), (j, k+1)
end
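# A cross-check, sketched with the exported SparseArrays accessors rather
# than raw fields: rowvals, nonzeros and nzrange(S, j) (the positions
# holding column j's stored entries) visit the stored entries in the same
# column-major order the iterator walks, here collected into a vector. The
# helper name stored_triples_reference is hypothetical.
using SparseArrays
function stored_triples_reference(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
    rows = rowvals(S)
    vals = nonzeros(S)
    out = Tuple{Ti,Ti,Tv}[]
    for j in 1:size(S, 2)
        for k in nzrange(S, j)
            v = vals[k]
            # Skip stored zeros and missing values, as the iterator does
            (ismissing(v) || iszero(v)) && continue
            push!(out, (rows[k], Ti(j), v))
        end
    end
    return out
end
let
    # Column 2 is empty and (2,3) is a stored zero
    S = sparse([1, 3, 2, 2], [1, 1, 3, 4], [10.0, 30.0, 0.0, 20.0], 3, 4)
    @assert stored_triples_reference(S) == [(1, 1, 10.0), (3, 1, 30.0), (2, 4, 20.0)]
end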
########################################################################
#
# A simple example program of usage
#
########################################################################
let
    # Construct a random sparse matrix
    m = 1000
    n = 100
    d = 0.4 # density
    M = sprandn(m, n, d)
    total = 0.0  # named `total` to avoid shadowing Base.sum
    @inbounds for (_, _, v) in StoredEntries(M)
        total += v
    end
    @info "sum = $total"
end
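# How a `for` loop like the one above is driven by the v0.7+ iteration
# protocol, sketched on the plain vector nonzeros(M) of an arbitrary small
# matrix so the example stands alone: the first call iterate(x) takes no
# state, each later call passes back the state returned by the previous
# one, and `nothing` ends the loop.
using SparseArrays
let
    M = sparse([1, 2], [1, 2], [1.5, 2.5])
    vals = nonzeros(M)
    total = 0.0
    next_item = iterate(vals)          # first call: no state argument
    while next_item !== nothing
        v, state = next_item
        total += v
        next_item = iterate(vals, state)  # later calls: pass the state back
    end
    @assert total == 4.0
end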