Last active
February 26, 2022 12:51
-
-
Save dermesser/3567a84919417c70b3b1abf12d043037 to your computer and use it in GitHub Desktop.
LazyLoader: a Julia utility that lazily loads images from disk, e.g. for machine learning with Flux. It can be used as a replacement for Flux.DataLoader when the dataset is too large to fit in memory.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Glob | |
import CUDA: CuArray | |
import Images | |
import Flux: gpu | |
import Flux | |
using Dates | |
"""
    logn(args...)

Print the current date and time, a single space, and then all of `args` (with no
separator between them) to `stdout`, followed by a newline.
"""
logn(args...) = println(now(), " ", args...)
# A read-only, vector-like view over a set of image files. Files are only opened when an
# element is indexed (`getindex` calls `Images.load`) or when batches are iterated.
# `T` is the pixel element type: the loading code builds `zeros(T, ...)` batches and
# applies `convert.(T, ...)` to each image.
# NOTE(review): `getindex` may return `nothing` (skipped image) or a GPU array, so the
# declared element type `Array{T}` is not always what indexing actually yields.
struct LazyLoader{T} <: AbstractArray{Array{T},1}
    # Glob pattern the file list was built from; also shown by `repr`.
    pattern::String
    # Encodes a vector of filenames into a OneHotMatrix
    labelfunc::Function
    # Number of images per batch produced by `iterate`.
    batchsize::Int
    # Colour channels per image; the loading code handles 1 (Gray) and 3 (RGB).
    channels::Int
    # Target size passed to `Images.imresize`; also the first two batch dimensions.
    resize::Tuple{Int,Int}
    # Image file paths; element `i` of the loader is loaded from `paths[i]`.
    paths::Vector{String}
end
"""
    LazyLoader{PixelType}(pattern, labelfunc, batchsize=16, resize=(64, 64), channels=3)

Build a loader over every file matching the glob `pattern`. The pattern is expanded once,
here, via `Glob.glob`. Images are read lazily when indexed or iterated.

# Arguments
- `pattern`: glob pattern selecting the image files.
- `labelfunc`: called with the vector of filenames of each batch to produce its labels.
- `batchsize`: number of images per batch yielded by `iterate`.
- `resize`: size every image is resized to.
- `channels`: 1 (grayscale) or 3 (RGB).

# Throws
- `ArgumentError` if `batchsize < 1`, if `channels` is not 1 or 3, or if a `resize`
  dimension is not positive.
"""
function LazyLoader{PixelType}(pattern::String, labelfunc::Function, batchsize::Int=16,
                               resize::Tuple{Int,Int}=(64, 64),
                               channels::Int=3)::LazyLoader{PixelType} where {PixelType}
    # A zero batch size would make `iterate` return the same state forever.
    batchsize >= 1 ||
        throw(ArgumentError("batchsize must be at least 1, got $batchsize"))
    # Any other channel count makes image loading produce no usable image.
    channels in (1, 3) ||
        throw(ArgumentError("channels must be 1 (grayscale) or 3 (RGB), got $channels"))
    all(d -> d > 0, resize) ||
        throw(ArgumentError("resize dimensions must be positive, got $resize"))
    files = Glob.glob(pattern)
    # The struct stores `channels` before `resize`, unlike this constructor's positional
    # order, hence the swap in the call below.
    return LazyLoader{PixelType}(pattern, labelfunc, batchsize, channels, resize, files)
end
"""
    iterate(ll::LazyLoader{T}, state::Int=0)

Load the next batch of images. `state` is the index (into `ll.paths`) of the last file
already consumed; the default of 0 starts at the first file.

Each step yields `((data, labels), next_state)`:
- `data` is a `T` array of size `(ll.resize..., ll.channels, ll.batchsize)`.
- `labels = ll.labelfunc(filenames)` for the files placed in `data`.
- Both are passed through `gpu`.

Files for which `ll[i]` returns `nothing` are skipped and do not occupy a batch slot.
Iteration ends (returns `nothing`) when the remaining files cannot fill a whole batch.
The trailing partial batch is dropped.
"""
function Base.iterate(
    ll::LazyLoader{T}, state::Int=0,
)::Union{Nothing, Tuple{Tuple{Union{Array,CuArray}, Flux.OneHotMatrix}, Int}} where {T}
    nfiles = length(ll.paths)
    # Not even enough files left for a full batch without skips.
    if state + ll.batchsize > nfiles
        return nothing
    end
    logn("Loading images $state + $(ll.batchsize)")
    data = zeros(T, ll.resize[1], ll.resize[2], ll.channels, ll.batchsize)
    fns = String[]
    j = 1   # next slot of `data` to fill
    pj = 1  # offset from `state` of the next file to try
    while j <= ll.batchsize
        idx = state + pj
        # Skipped images can exhaust the files before the batch is full; treat that like
        # the "not enough files" case above instead of indexing out of bounds.
        idx > nfiles && return nothing
        img = ll[idx]
        pj += 1
        isnothing(img) && continue
        # NOTE(review): `ll[idx]` already applies `gpu`; with CUDA enabled this copies a
        # device array into the host-side `data` buffer -- confirm this is intended.
        data[:, :, :, j] .= img
        push!(fns, ll.paths[idx])
        j += 1
    end
    labels = ll.labelfunc(fns)
    # `state + pj - 1` is the index of the last file examined for this batch.
    return ((data |> gpu, labels |> gpu), state + pj - 1)
end
"""
    size(ll::LazyLoader)

Return the dimensions of `ll` as the 1-tuple `(number_of_files,)`.

The `AbstractArray` interface requires `size` to return a tuple. Generic methods such as
`axes`, `eachindex` and array printing depend on it. `length(ll)` is unaffected, since it
is `prod(size(ll))`.
"""
function Base.size(ll::LazyLoader{T}) where {T}
    return (length(ll.paths),)
end
"""
    repr(ll::LazyLoader{T}) -> String

Summarise the loader on one line: pixel type, number of files, channel count and the
glob pattern the files came from.
"""
Base.repr(ll::LazyLoader{T}) where {T} =
    "LazyLoader{$T} with $(length(ll.paths)) files ($(ll.channels) channels) in $(ll.pattern)"
"""
    show(io::IO, ll::LazyLoader)
    show(io::IO, ::MIME"text/plain", ll::LazyLoader)

Print the one-line summary produced by `repr(ll)` to `io`.

The `text/plain` method matters because `LazyLoader` is an `AbstractArray`. Multi-line
displays (e.g. `show(stdout, MIME"text/plain"(), ll)`, notebooks) would otherwise use the
generic array printer, which indexes elements and would load images from disk.
"""
function Base.show(io::IO, ll::LazyLoader)
    print(io, repr(ll))
    return nothing
end

Base.show(io::IO, ::MIME"text/plain", ll::LazyLoader) = show(io, ll)
# One-argument form: print the loader summary to standard output.
Base.show(ll::LazyLoader) = Base.show(Base.stdout, ll)
# Override the generic AbstractArray display (presumably so the REPL does not try to
# print -- and therefore load -- every element) and show the one-line summary instead.
Base.display(ll::LazyLoader) = Base.show(ll)
"""
    getindex(ll::LazyLoader{T}, i::Int)

Load the image at `ll.paths[i]` and resize it to `ll.resize`. Return its pixels as an
array of `T`, passed through `gpu`. The array has size `(ll.resize..., 3)` for RGB
(`ll.channels == 3`) or `(ll.resize..., 1)` for grayscale (`ll.channels == 1`).

Return `nothing` if the image has fewer pixels than `prod(ll.resize)`; batch iteration
skips such files.

# Throws
- `ArgumentError` if `ll.channels` is neither 1 nor 3.
- `BoundsError` if `i` is not a valid index into `ll.paths`.
"""
function Base.getindex(ll::LazyLoader{T}, i::Int)::Union{Nothing, Union{Array,CuArray}} where {T}
    # Fail loudly: an unsupported channel count used to fall through both branches and
    # return `nothing`, which callers treat as "image too small" and silently skip.
    ll.channels in (1, 3) ||
        throw(ArgumentError("unsupported channel count $(ll.channels); expected 1 or 3"))
    img = Images.load(ll.paths[i])
    # Images with fewer pixels than the target are rejected rather than upscaled.
    if length(img) < prod(ll.resize)
        return nothing
    end
    img = Images.imresize(img, ll.resize)
    if ll.channels == 3
        # Stack the red, green and blue planes along a third dimension.
        planes = cat(Images.red.(img), Images.green.(img), Images.blue.(img), dims=3)
        return convert.(T, planes) |> gpu
    else
        # Indexing a matrix with three colons adds a trailing singleton (channel) dim.
        return convert.(T, Images.Gray.(img)[:, :, :]) |> gpu
    end
end
# LazyLoader is read-only: assigning to any element raises an error.
Base.setindex!(::LazyLoader, v, ::Int) = error("Cannot write to LazyLoader elements!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment