Created February 7, 2022 22:12
DLPack reproduce segfault
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore

torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")

# Close over the buffers so the stateless module can be called as fn(params, inputs).
py"""
def buffer_implicit(fn, buffers):
    def newfn(params, inputs):
        return fn(params, buffers, inputs)
    return newfn
"""

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

reversedims(x::AbstractArray{T,N}) where {T,N} = permutedims(x, N:-1:1)

function ReverseDimsArray(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
    PermutedDimsArray(a, N:-1:1)
end

# Holds a functorch stateless module together with Julia-side copies of its parameters.
struct TorchModuleWrapper
    torch_stateless_module::PyObject
    dtype::PyObject
    device::PyObject
    params::Tuple
    buffers::Tuple
end

function TorchModuleWrapper(torch_module)
    pybuiltin("isinstance")(torch_module, torch.nn.Module) || error("Not a torch.nn.Module")
    device = torch.device("cpu")
    funmod, params, buffers = functorch.make_functional_with_buffers(torch_module)
    dtype = params[1].dtype
    # TODO: shouldn't require reversedims
    # Ideally should not even require conversion to Array; it's already DLPack
    jlparams = map(params) do x
        reversedims(Array(DLArray(x, pyto_dlpack)))
    end
    return TorchModuleWrapper(funmod, dtype, device, jlparams, buffers)
end

maybecontiguous(x::AbstractArray) = Array(x)
maybecontiguous(x::StridedArray) = x

# Forward pass: share Julia arrays with torch via DLPack and call the stateless module.
function (wrap::TorchModuleWrapper)(args...)
    # TODO: handle multiple outputs
    params = wrap.params
    tensor_out = wrap.torch_stateless_module(Tuple(map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), params)),
        wrap.buffers, map(x -> DLPack.share((x), pyfrom_dlpack), args)...)
    res = ReverseDimsArray(DLArray(tensor_out, pyto_dlpack))
    return res
end

# Reverse rule: delegate the pullback to functorch.vjp.
function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...)
    params = wrap.params
    torch_primal, torch_vjpfun = functorch.vjp(py"buffer_implicit"(wrap.torch_stateless_module, wrap.buffers), Tuple(map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), params)),
        map(x -> DLPack.share((x), pyfrom_dlpack).to(dtype = wrap.dtype, device = wrap.device).requires_grad_(true), args)...)
    project = ProjectTo(args)
    function TorchModuleWrapper_pullback(Δ)
        torch_tangent_vals = torch_vjpfun(DLPack.share((maybecontiguous(Δ)), pyfrom_dlpack))
        jlparams_tangents = map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[1])
        args_tangents = project(map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[2:end]))
        return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), device = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...)
    end
    res = ReverseDimsArray(DLArray(torch_primal, pyto_dlpack))
    return res, TorchModuleWrapper_pullback
end

batchsize = 1
indim = 3
outdim = 2
hiddendim = 4

# Compare gradients w.r.t. parameters computed by torch.autograd and by Zygote.
function compare_grad_wrt_params(modelwrap, inputs...)
    params = map(x -> torch.as_tensor(copy(ReverseDimsArray(x))).to(device = modelwrap.device, dtype = modelwrap.dtype).requires_grad_(true), (modelwrap.params))
    torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z -> torch.as_tensor(PyReverseDims(copy(z))).to(dtype = modelwrap.dtype), inputs)...).sum()
    torchgrad = map(x -> ReverseDimsArray(x.numpy()), torch.autograd.grad(torch_out, params))
    grad, = Zygote.gradient(m -> sum(m(inputs...)), modelwrap)
    @test length(torchgrad) == length(grad.params)
    for i in 1:length(grad.params)
        @test isapprox(torchgrad[i], grad.params[i])
    end
    @test length(grad.params) == length(modelwrap.params)
    @test grad.params[1] !== nothing
    @test grad.params[2] !== nothing
    @test size(grad.params[1]) == size(modelwrap.params[1])
    @test size(grad.params[2]) == size(modelwrap.params[2])
end

lin = torch.nn.Linear(indim, outdim)
torchparams = Tuple([copy(DLArray(p, pyto_dlpack)) for p in lin.parameters()])  # (outdim, indim), (outdim,)
linwrap = TorchModuleWrapper(lin)
x = randn(Float32, indim, batchsize)
y = linwrap(x)
compare_grad_wrt_params(linwrap, deepcopy(x))
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│ caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: (nil))
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: 0x961e6ff)
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: 0x7f2fa3f509bf)
signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:104
jl_system_image_data at /home/jagupt/.julia/juliaup/julia-1.7.1+0~x64/lib/julia/sys.so (unknown line)
Allocations: 64202361 (Pool: 64181938; Big: 20423); GC: 67
Does changing the first two lines to
using Distributed
@everywhere using PyCall
@everywhere using DLPack
fix the issue?
EDIT: No, I don't think it does.
Yep, it doesn't fix the issue. Although I would love to understand the train of thought that led to this test!
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│ caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
Expression: size(grad.params[2]) == size(modelwrap.params[2])
Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:107
unknown function (ip: 0x7f39d7588580)
Allocations: 62960043 (Pool: 62939873; Big: 20170); GC: 69
OK, I think I found the culprit. Going over the PyTorch repo, I found that the tensor is captured inside the DLManagedTensor (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/DLConvertor.cpp#L243), so we need to keep not only the array around, but also the tensor alive. I'll update the PR.
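To illustrate the keep-alive idea, here is a minimal sketch (not the actual PR; TENSOR_ROOTS and wrap_tensor are hypothetical names, and it assumes the object returned by DLArray is mutable so it can serve as a WeakKeyDict key):
using PyCall
using DLPack

# Root each source torch tensor for as long as the Julia array that aliases
# its storage is reachable; the WeakKeyDict drops the entry (and with it the
# tensor reference) once the array itself is garbage collected.
const TENSOR_ROOTS = WeakKeyDict{Any,PyObject}()

function wrap_tensor(tensor::PyObject, pyto_dlpack)
    arr = DLArray(tensor, pyto_dlpack)  # Julia view of the tensor's memory
    TENSOR_ROOTS[arr] = tensor          # keep the tensor, not just the array, alive
    return arr
end
With something along these lines, the DLArray(x, pyto_dlpack) calls in the wrapper above could go through wrap_tensor, so PyTorch's captured tensor is not freed while the Julia side still uses the memory.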
Can you add
@show DLPack.PYCALL_NOOP_DELETER
just after loading the packages and then share the output when it segfaults?
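For reference, a sketch of the requested change at the top of the test file (only the @show line is new; the remaining imports and test code stay unchanged):
using PyCall
using DLPack

# Print the pointer value of DLPack's no-op deleter so segfaulting runs can be
# compared against successful ones.
@show DLPack.PYCALL_NOOP_DELETER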