Skip to content

Instantly share code, notes, and snippets.

@dermesser
Last active February 26, 2022 12:51
Show Gist options
  • Save dermesser/3567a84919417c70b3b1abf12d043037 to your computer and use it in GitHub Desktop.
Save dermesser/3567a84919417c70b3b1abf12d043037 to your computer and use it in GitHub Desktop.
LazyLoader: A Julia utility that lazily loads images from disk for use in, e.g., Flux for machine learning. It can be used as a replacement for Flux.DataLoader if the dataset is too large to fit in memory.
import Glob
import CUDA: CuArray
import Images
import Flux: gpu
import Flux
using Dates
# Print the current timestamp, a single space and then all of `args`
# (concatenated without separators) to standard output, ending with a newline.
logn(args...) = println(now(), " ", args...)
# Loader that reads image files from disk on demand instead of keeping them
# all in memory. Declared as a one-dimensional AbstractArray whose element
# type is Array{T}; T is the numeric type pixel values are converted to.
struct LazyLoader{T} <: AbstractArray{Array{T},1}
# Glob pattern; the outer constructor expands it into `paths`.
pattern::String
# Called with the vector of file names that make up a batch; expected to
# encode them into a OneHotMatrix of labels.
labelfunc::Function
# Number of images collected into one batch by `iterate`.
batchsize::Int
# Number of colour channels per image; `getindex` handles the values 3 and 1.
channels::Int
# Size handed to Images.imresize; also the first two dimensions of a batch array.
resize::Tuple{Int,Int}
# File paths of the images; `getindex` loads `paths[i]`.
paths::Vector{String}
end
# Convenience constructor: expand the glob `pattern` into a list of file paths
# and build a LazyLoader{PixelType} over them.
#
# Arguments
#   pattern   - glob pattern passed to Glob.glob
#   labelfunc - function applied to the file names of each batch to get labels
#   batchsize - images per batch (default 16)
#   resize    - size every image is resized to (default (64, 64))
#   channels  - colour channels per image (default 3)
#
# Note that the struct stores `channels` before `resize`, so the two are
# swapped when forwarding to the field-order constructor.
function LazyLoader{PixelType}(
    pattern::String,
    labelfunc::Function,
    batchsize::Int=16,
    resize::Tuple{Int,Int}=(64, 64),
    channels::Int=3,
)::LazyLoader{PixelType} where {PixelType}
    matched = Glob.glob(pattern)
    return LazyLoader{PixelType}(pattern, labelfunc, batchsize, channels, resize, matched)
end
# Produce the next batch of images together with their labels.
#
# `state` is the index into `ll.paths` of the first file that has not been
# looked at yet; iteration starts at 1 so that the first file is included.
# Files for which `ll[i]` yields `nothing` are skipped and the following files
# are used to fill the batch instead.
#
# Returns `((data, labels), nextstate)`, where `data` has dimensions
# `(ll.resize[1], ll.resize[2], ll.channels, ll.batchsize)` and `labels` is
# `ll.labelfunc` applied to the file names in the batch; both are piped
# through `gpu`. Returns `nothing` when fewer than `ll.batchsize` usable
# images remain, i.e. an incomplete trailing batch is dropped.
#
# Throws an ArgumentError if `ll.batchsize` is not positive, because iteration
# could then never make progress.
function Base.iterate(
    ll::LazyLoader{T}, state::Int=1
)::Union{Nothing, Tuple{Tuple{Union{Array,CuArray}, Flux.OneHotMatrix}, Int}} where {T}
    if ll.batchsize <= 0
        throw(ArgumentError("batchsize must be positive, got $(ll.batchsize)"))
    end
    npaths = length(ll.paths)
    # Not enough files left to fill a whole batch, even if all were usable.
    if state + ll.batchsize - 1 > npaths
        return nothing
    end
    logn("Loading images $state + $(ll.batchsize)")
    data = zeros(T, ll.resize[1], ll.resize[2], ll.channels, ll.batchsize)
    fns = String[]
    sizehint!(fns, ll.batchsize)
    next = state
    filled = 0
    while filled < ll.batchsize
        # Skipped images may exhaust the file list before the batch is full;
        # stop cleanly instead of indexing past the end of `ll.paths`.
        if next > npaths
            return nothing
        end
        img = ll[next]
        if !isnothing(img)
            filled += 1
            # `getindex` pipes its result through `gpu`, so `img` may be a
            # device array while `data` is a host array; copy it to the host
            # before writing into the batch buffer.
            data[:, :, :, filled] .= Array(img)
            push!(fns, ll.paths[next])
        end
        next += 1
    end
    labels = ll.labelfunc(fns)
    return ((data |> gpu, labels |> gpu), next)
end
# Dimensions of the loader: a 1-tuple holding the number of image files.
#
# The AbstractArray interface requires `size` to return a tuple; generic
# functions such as `axes`, `eachindex` and `length` are derived from it.
function Base.size(ll::LazyLoader{T}) where {T}
    return (length(ll.paths),)
end
# One-line textual summary of the loader: pixel type, number of files,
# channel count and the glob pattern.
function Base.repr(ll::LazyLoader{T})::String where {T}
    nfiles = length(ll.paths)
    nchannels = ll.channels
    return "LazyLoader{$T} with $nfiles files ($nchannels channels) in $(ll.pattern)"
end
# Write the one-line summary produced by `Base.repr(ll)` to `io`.
# Always returns `nothing`.
function Base.show(io::IO, ll::LazyLoader)
    summary_text = Base.repr(ll)
    Base.write(io, summary_text)
    return nothing
end
# Show the loader on standard output by forwarding to the two-argument `show`.
Base.show(ll::LazyLoader) = Base.show(Base.stdout, ll)
# Replace the AbstractArray display for loaders with the result of
# `Base.show(ll)`.
Base.display(ll::LazyLoader) = Base.show(ll)
# Load, resize and convert the `i`-th image.
#
# Returns an array with element type `T` and a trailing channel dimension
# (three planes red/green/blue for `channels == 3`, one grayscale plane for
# `channels == 1`), piped through `gpu`. Returns `nothing` when the image has
# fewer pixels than the target size `ll.resize`, so callers can skip it.
#
# Throws an ArgumentError when `ll.channels` is neither 1 nor 3; previously
# this case fell through and returned `nothing`, which was indistinguishable
# from "image too small".
function Base.getindex(
    ll::LazyLoader{T}, i::Int
)::Union{Nothing, Union{Array,CuArray}} where {T}
    if ll.channels != 1 && ll.channels != 3
        throw(ArgumentError("LazyLoader supports 1 or 3 channels, got $(ll.channels)"))
    end
    img = Images.load(ll.paths[i])
    # Images with fewer pixels than the target are rejected, not enlarged.
    if length(img) < prod(ll.resize)
        return nothing
    end
    img = Images.imresize(img, ll.resize)
    if ll.channels == 3
        planes = cat(Images.red.(img), Images.green.(img), Images.blue.(img), dims=3)
    else
        # Indexing with three colons appends a singleton channel dimension.
        planes = Images.Gray.(img)[:, :, :]
    end
    return convert.(T, planes) |> gpu
end
# Loaders are read-only: any attempt to assign to an element raises an error.
Base.setindex!(ll::LazyLoader{T}, v, i::Int) where {T} =
    error("Cannot write to LazyLoader elements!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment