DLPack reproduce segfault
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore

torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")

# Python helper: close over the buffers so the stateless module can be
# called as fn(params, inputs).
py"""
def buffer_implicit(fn, buffers):
    def newfn(params, inputs):
        return fn(params, buffers, inputs)
    return newfn
"""

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

# Torch tensors are row-major while Julia arrays are column-major, so
# dimensions are reversed when crossing the boundary.
reversedims(x::AbstractArray{T,N}) where {T,N} = permutedims(x, N:-1:1)

function ReverseDimsArray(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
    PermutedDimsArray(a, N:-1:1)
end
struct TorchModuleWrapper
    torch_stateless_module::PyObject
    dtype::PyObject
    device::PyObject
    params::Tuple
    buffers::Tuple
end

function TorchModuleWrapper(torch_module)
    pybuiltin("isinstance")(torch_module, torch.nn.Module) || error("Not a torch.nn.Module")
    device = torch.device("cpu")
    funmod, params, buffers = functorch.make_functional_with_buffers(torch_module)
    dtype = params[1].dtype
    # TODO: shouldn't require reversedims
    # Ideally this should not even require conversion to Array; it's already DLPack
    jlparams = map(params) do x
        reversedims(Array(DLArray(x, pyto_dlpack)))
    end
    return TorchModuleWrapper(funmod, dtype, device, jlparams, buffers)
end
# Materialize a contiguous Array unless the input is already strided.
maybecontiguous(x::AbstractArray) = Array(x)
maybecontiguous(x::StridedArray) = x

# Forward pass: share the Julia params and inputs with torch via DLPack and
# call the stateless module.
function (wrap::TorchModuleWrapper)(args...)
    # TODO: handle multiple outputs
    params = wrap.params
    tensor_out = wrap.torch_stateless_module(Tuple(map(x -> DLPack.share(x, pyfrom_dlpack).requires_grad_(true), params)),
        wrap.buffers, map(x -> DLPack.share(x, pyfrom_dlpack), args)...)
    res = ReverseDimsArray(DLArray(tensor_out, pyto_dlpack))
    return res
end
function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...)
    params = wrap.params
    # functorch.vjp on the buffer-closed stateless module gives both the primal
    # output and the pullback on the torch side.
    torch_primal, torch_vjpfun = functorch.vjp(py"buffer_implicit"(wrap.torch_stateless_module, wrap.buffers),
        Tuple(map(x -> DLPack.share(x, pyfrom_dlpack).requires_grad_(true), params)),
        map(x -> DLPack.share(x, pyfrom_dlpack).to(dtype = wrap.dtype, device = wrap.device).requires_grad_(true), args)...)
    project = ProjectTo(args)
    function TorchModuleWrapper_pullback(Δ)
        torch_tangent_vals = torch_vjpfun(DLPack.share(maybecontiguous(Δ), pyfrom_dlpack))
        jlparams_tangents = map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[1])
        args_tangents = project(map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[2:end]))
        return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), device = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...)
    end
    res = ReverseDimsArray(DLArray(torch_primal, pyto_dlpack))
    return res, TorchModuleWrapper_pullback
end
batchsize = 1
indim = 3
outdim = 2
hiddendim = 4

# Check that Zygote's gradient w.r.t. the wrapped parameters matches
# torch.autograd on the same stateless module.
function compare_grad_wrt_params(modelwrap, inputs...)
    params = map(x -> torch.as_tensor(copy(ReverseDimsArray(x))).to(device = modelwrap.device, dtype = modelwrap.dtype).requires_grad_(true), modelwrap.params)
    torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z -> torch.as_tensor(PyReverseDims(copy(z))).to(dtype = modelwrap.dtype), inputs)...).sum()
    torchgrad = map(x -> ReverseDimsArray(x.numpy()), torch.autograd.grad(torch_out, params))
    grad, = Zygote.gradient(m -> sum(m(inputs...)), modelwrap)
    @test length(torchgrad) == length(grad.params)
    for i in 1:length(grad.params)
        @test isapprox(torchgrad[i], grad.params[i])
    end
    @test length(grad.params) == length(modelwrap.params)
    @test grad.params[1] !== nothing
    @test grad.params[2] !== nothing
    @test size(grad.params[1]) == size(modelwrap.params[1])
    @test size(grad.params[2]) == size(modelwrap.params[2])
end
lin = torch.nn.Linear(indim, outdim)
torchparams = Tuple([copy(DLArray(p, pyto_dlpack)) for p in lin.parameters()])  # shapes: (outdim, indim), (outdim,)
linwrap = TorchModuleWrapper(lin)
x = randn(Float32, indim, batchsize)
y = linwrap(x)
compare_grad_wrt_params(linwrap, deepcopy(x))
Ok, I think I found the culprit. Going over the PyTorch repo, I found that the DLManagedTensor is captured (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/DLConvertor.cpp#L243), so we need to keep not only the array around but also the tensor. I'll update the PR.
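A minimal Julia sketch of that idea, assuming a hypothetical wrapper (TensorBackedArray and tensor_backed_array are illustrative names, not the actual DLPack.jl API): the Julia-side array keeps a reference to the originating torch tensor, so the PyObject cannot be collected, and the DLManagedTensor it captured cannot be freed, while the shared memory is still in use.

# Hedged sketch of the fix described above; names are illustrative.
struct TensorBackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A            # array view obtained from the DLPack capsule
    tensor::PyObject   # originating torch tensor, kept alive with the view
end

Base.size(x::TensorBackedArray) = size(x.data)
Base.getindex(x::TensorBackedArray, I...) = getindex(x.data, I...)
Base.setindex!(x::TensorBackedArray, v, I...) = setindex!(x.data, v, I...)

# Wrap the capsule but also capture the tensor itself, so Julia's GC cannot
# release the PyObject (and thereby trigger the DLManagedTensor deleter)
# while the array is still referenced.
function tensor_backed_array(torch_tensor::PyObject, to_dlpack)
    arr = DLArray(torch_tensor, to_dlpack)      # same call used in the script above
    return TensorBackedArray(arr, torch_tensor)
end

With something along these lines, the reproduction script above should stop segfaulting, since the tensor stays rooted for the lifetime of every array handed back to Julia.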