Created August 16, 2019 20:39
Save Roger-luo/1dc955595eb40cef9a77b0e145ece153 to your computer and use it in GitHub Desktop.
Alloc.jl in Cassette
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module HoistMem

export hoist_alloc, Buffer

using Cassette, LinearAlgebra
using Cassette: @context, overdub

# Cassette context used by `hoist_alloc`; its `metadata` is the `Buffer` that serves the
# intercepted array allocations.
@context BuffCtx

"""
    Buffer(n::Integer)
    Buffer(buf::Vector{UInt8}, offset)

Bump allocator over the byte vector `buf`; `offset` counts the bytes already handed
out since the last `clear!`.

Arrays carved out of a `Buffer` are `unsafe_wrap`ped views of `buf`. They do not keep
`buf` alive, so the `Buffer` must stay reachable while they are in use, and their
contents are overwritten once the buffer is cleared and reused.
"""
mutable struct Buffer
    buf::Vector{UInt8}
    offset::UInt
end

Buffer(n::Integer) = Buffer(Vector{UInt8}(undef, n), 0)

Base.copy(b::Buffer) = Buffer(copy(b.buf), b.offset)

# Explicit no-op prehook for every call traced under `BuffCtx`.
Cassette.prehook(::BuffCtx, f, xs...) = nothing

"""
    alloc(b::Buffer, Array{T,N}, d::NTuple{N,Int}) -> Array{T,N}

Return an uninitialized `Array{T,N}` of size `d` backed by the free space of `b`,
and advance `b.offset` past it. The start address is rounded up to the alignment
Julia requires for `T`.

Throws `ArgumentError` if `T` is not an isbits type (GC-tracked references cannot
live in raw bytes) or if a dimension is negative, and `ErrorException`
("Alloc: Out of memory") when `b` is too small. `b` is left unchanged on error.
"""
function alloc(b::Buffer, ::Type{Array{T,N}}, d::NTuple{N,Int}) where {T,N}
    isbitstype(T) ||
        throw(ArgumentError("Alloc: element type $T is not an isbits type"))
    any(n -> n < 0, d) && throw(ArgumentError("invalid Array dimensions"))
    base = UInt(pointer(b.buf))
    align = UInt(Base.datatype_alignment(T))
    # Align the absolute address (not just the offset) so the wrapped array is valid for T.
    start = cld(base + b.offset, align) * align - base
    stop = start + UInt(sizeof(T) * prod(d))
    # Check before mutating so a failed request does not consume buffer space.
    stop > length(b.buf) && error("Alloc: Out of memory")
    b.offset = stop
    return unsafe_wrap(Array, convert(Ptr{T}, pointer(b.buf) + start), d)
end

"""
    clear!(b::Buffer) -> b

Mark all of `b` as free again. Arrays previously handed out by `b` still wrap the
same bytes and will be overwritten by later allocations.
"""
function clear!(b::Buffer)
    b.offset = 0
    return b
end

"""
    hoist_alloc(f, b::Buffer)

Clear `b`, then run the zero-argument function `f` under `BuffCtx` and return its
result. While `f` runs, two kinds of allocation are served from `b` instead of the
heap, provided the element type `T` is an isbits type:

- `Array{T,N}(undef, dims...)` calls with a fully specified `Array{T,N}` (e.g.
  `Vector{Float64}(undef, n)`);
- `similar(bc, T)` calls on `DefaultArrayStyle` `Broadcasted` objects, which Base
  uses to allocate broadcast results.

All other allocations use the heap as usual.

Arrays obtained from `b` alias its bytes. They are overwritten by the next
`hoist_alloc`/`clear!` on `b`, and `b` must be kept alive while they are used.
"""
function hoist_alloc(f, b::Buffer)
    clear!(b)
    return overdub(BuffCtx(metadata=b), f)
end

# Functions run natively (untraced) under `BuffCtx`; allocations inside them therefore
# use the regular heap. Functions from other packages (e.g. project-specific batched
# kernels) can be registered the same way from downstream code:
#     Cassette.overdub(::HoistMem.BuffCtx, f::typeof(myfunc), xs...) = f(xs...)
for F in (LinearAlgebra.mul!, Base.promote_op, size, Base.to_shape,
          Broadcast.broadcasted, Broadcast.instantiate, Broadcast.preprocess,
          Broadcast.combine_eltypes, copyto!, Broadcast.copyto_nonleaf!,
          Broadcast.axes, Base.getindex, Base.setindex!, Base.fill!)
    @eval @inline Cassette.overdub(::BuffCtx, f::typeof($F), xs...) = f(xs...)
end

# Broadcast output allocation: take it from the buffer when the element type allows.
@inline function Cassette.overdub(ctx::BuffCtx, ::typeof(similar),
                                  bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}},
                                  ::Type{T}) where {T,N}
    # Non-isbits elements are GC-tracked references and must not live in raw bytes.
    isbitstype(T) || return similar(bc, T)
    return alloc(ctx.metadata, Array{T,N}, length.(axes(bc)))
end

# `Array{T,N}(undef, dims...)`: take it from the buffer when the element type allows.
@inline function Cassette.overdub(ctx::BuffCtx, ::Type{Array{T,N}}, ::UndefInitializer,
                                  d::Vararg{Int,N}) where {T,N}
    isbitstype(T) || return Array{T,N}(undef, d)
    return alloc(ctx.metadata, Array{T,N}, d)
end

export mprofile

# Cassette context used by `mprofile`; its `metadata` is a `Ref{Int}` byte counter.
@context ProfileCtx

# Approximate bytes per element of an `Array{T}`: isbits elements are counted at
# `sizeof(T)`, every other element type as one pointer-sized reference. This avoids
# calling `sizeof` on abstract or variable-size types, which throws.
_elsize(::Type{T}) where {T} = isbitstype(T) ? sizeof(T) : sizeof(Ptr{Cvoid})

# Count the bytes requested by each traced `Array{T,N}(undef, dims...)` call.
function Cassette.prehook(cx::ProfileCtx, ::Type{Array{T,N}}, ::UndefInitializer,
                          d::Vararg{Int,N}) where {T,N}
    cx.metadata[] += _elsize(T) * prod(d)
    return nothing
end

"""
    mprofile(f)

Run the zero-argument function `f` under `ProfileCtx`, log (via `@info`) the total
number of bytes requested through traced `Array{T,N}(undef, dims...)` calls, and
return `f()`'s result. Only those constructor calls are counted.
"""
function mprofile(f)
    ctx = ProfileCtx(metadata = Ref(0))
    x = overdub(ctx, f)
    @info "allocated $(ctx.metadata[]) bytes"
    return x
end

end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.