Last active
December 4, 2019 05:52
-
-
Save jiahao/8aaaa6e57b73396b7351ad7875436f24 to your computer and use it in GitHub Desktop.
Iterate over SparseMatrixCSC stored entries. Implements Julia's new iterator protocol (new as of v0.7) https://julialang.org/blog/2018/07/iterators-in-julia-0.7
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Iterate over the stored entries of a SparseMatrixCSC, skipping stored
# zeros and missing values.
#
# Implements the iteration protocol introduced in Julia v0.7.
# Ref: https://julialang.org/blog/2018/07/iterators-in-julia-0.7
#
# Jiahao Chen 2019-12-03
#
# MIT License available upon request
#
# Implementation note:
#
# This version of the iterator skips stored zeros and missing values.
# The comments in the implementation explain what to change if you do
# not want this behavior.
#
########################################################################
#
# The implementation
#
########################################################################
using SparseArrays | |
"""
    StoredEntries(A::SparseMatrixCSC)

Iterator over the stored entries of `A` whose values are neither zero nor
`missing`, visited in storage order (column by column). Each element is a
tuple `(i, j, v)` with row index `i`, column index `j` and value `v`.
"""
struct StoredEntries{T}
    A::T
end

# Predicate for the stored values the iterator skips; `length` and `iterate`
# both use it, so they always agree. To visit every stored value (including
# explicit zeros and `missing`), redefine it as `_isskipped(v) = false`.
_isskipped(v) = ismissing(v) || iszero(v)

# Number of elements the iterator yields. Only the first `nnz` slots of the
# value storage are live; the underlying vector may be longer.
function Base.length(A::StoredEntries{SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti}
    vals = nonzeros(A.A)
    return count(k -> !_isskipped(vals[k]), 1:nnz(A.A))
end

# Element type (i, j, v): row and column indices in the matrix's index type
# Ti, and the stored value type Tv. Defined on the type so that both
# `eltype(itr)` and `eltype(typeof(itr))` give the precise answer.
Base.eltype(::Type{StoredEntries{SparseMatrixCSC{Tv,Ti}}}) where {Tv,Ti} =
    Tuple{Ti,Ti,Tv}

# The iteration state is a pair of integers:
#   - j, the column currently being scanned
#   - k, the index of the next storage slot to examine
function Base.iterate(A::StoredEntries{SparseMatrixCSC{Tv,Ti}},
                      state=(1, 1)) where {Tv,Ti}
    S = A.A
    vals = nonzeros(S)
    nstored = nnz(S)
    j, k = state
    # Skip zero/missing values without reading past the last live slot;
    # trailing skipped values simply end the iteration.
    while k <= nstored && _isskipped(vals[k])
        k += 1
    end
    k > nstored && return nothing  # signals the end of iteration
    # Column j owns slots colptr[j]:colptr[j+1]-1 (empty columns own none),
    # so move right until slot k lies inside column j. This terminates
    # because colptr[end] == nstored + 1 > k.
    while k >= S.colptr[j + 1]
        j += 1
    end
    i = rowvals(S)[k]
    # Convert j so the element matches `eltype` (Tuple{Ti,Ti,Tv}); the
    # state itself is kept as plain Ints.
    return (i, Ti(j), vals[k]), (j, k + 1)
end
########################################################################
#
# A simple usage example
#
########################################################################
let
    # Construct a random sparse matrix with m rows, n columns and roughly
    # a fraction d of its entries stored.
    m = 1000
    n = 100
    d = 0.4  # density
    M = sprandn(m, n, d)
    # Accumulate into `total` rather than `sum` so Base.sum is not shadowed.
    # No @inbounds: it has no effect on calls into `iterate`.
    total = 0.0
    for (_, _, v) in StoredEntries(M)
        total += v
    end
    @info "sum = $total"
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment