Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Created February 18, 2020 20:40
Show Gist options
  • Select an option

  • Save Roger-luo/f2bfe56d882c06e9905fad8e4e1cf826 to your computer and use it in GitHub Desktop.

Select an option

Save Roger-luo/f2bfe56d882c06e9905fad8e4e1cf826 to your computer and use it in GitHub Desktop.
Performance regression of Zygote relative to Tracker: an MPS case
using TNFilters
using Flux
using Zygote
using BenchmarkTools
using TNFilters: bmm, bmm!, batched_tr
using Flux: params
using Zygote: AContext, Context, _pullback, cache, accum_param
# I have not found the cause yet; the manually written adjoint gives roughly the same performance as Tracker (Tracker takes about 35μs)
# Random MPS built by TNFilters.MPS from the callback `(l, r, B) -> rand(l, r, B)`.
# NOTE(review): the meaning of the type parameter `10` and of the 37 dimensions
# passed in is defined by `MPS` itself — presumably batch size and bond
# dimensions; confirm against TNFilters.
mps = MPS{10}((l, r, B) -> rand(l, r, B), ntuple(_ -> 2, 36 + 1))
# Configuration to evaluate: 36 entries, each drawn from 1:2.
S = rand(1:2, 36)
ps = params(mps)
# Register every parameter in a fresh context's cache with a `nothing`
# placeholder — presumably so `accum_param` treats them as tracked parameters
# and accumulates their gradients into `cx`.
cx = Context()
foreach(p -> (cache(cx)[p] = nothing), ps)
mps.tensors  # NOTE(review): result is discarded; looks like a leftover REPL inspection
# Field access on an MPS returns the field unchanged, and its pullback yields
# `nothing`, so no gradient flows back through `getproperty`/`getfield`.
Zygote.@adjoint function getproperty(mps::MPS, f)
    return getproperty(mps, f), _ -> nothing
end
Zygote.@adjoint function getfield(mps::MPS, f)
    return getfield(mps, f), _ -> nothing
end
# Zygote-generated adjoint: the backward pass takes about 500μs.
_, back = _pullback(cx, mps, S)
# Interpolate `back` as well as the cotangent: `back` is a non-const global, and
# leaving it bare makes every sample also time a global lookup and a dynamic
# dispatch. The length 10 presumably matches the batch size of `MPS{10}`.
@benchmark $back($(rand(10)))
# Manually written adjoint: the backward pass takes about 45μs
Zygote.@adjoint function (op::MPS)(configs::AbstractVector)
    # Hand-written adjoint for evaluating an MPS on a configuration.
    #
    # Forward: select the slice `op.tensors[i][configs[i]]` at every site,
    # contract the slices left to right with `bmm`, then apply `batched_tr`.
    # The pullback of every `bmm` step is recorded so that the backward pass
    # can walk the chain right to left.
    #
    # Backward: the slice gradients are accumulated into `__context__` with
    # `accum_param`, and are also returned as an explicit gradient for the
    # `tensors` field in site order (1, 2, ..., n).
    @boundscheck if length(configs) != length(op)
        error("length of the configuration does not match mps length")
    end
    n = length(op)
    output = @inbounds op.tensors[1][configs[1]]
    # site_backs[i-1] is the pullback of the `bmm` step that absorbed site i
    site_backs = Vector{Any}(undef, n - 1)
    @inbounds for i in 2:n
        output, site_back = _pullback(__context__, bmm, output, op.tensors[i][configs[i]])
        site_backs[i-1] = site_back
    end
    # Use the same context as the `bmm` steps rather than a fresh one.
    y, tr_back = _pullback(__context__, batched_tr, output)

    function mps_back(Δ)
        _, Δ = tr_back(Δ)
        # grads[i] holds the gradient of the slice selected at site i. The
        # previous version pushed in backward order, which reversed the tuple.
        # NOTE(review): each entry is the gradient of `op.tensors[i][configs[i]]`,
        # not of `op.tensors[i]` as a whole — confirm this matches what callers
        # expect for the `tensors` field.
        grads = Vector{Any}(undef, n)
        for i in n:-1:2
            # each `bmm` pullback returns (Δf, Δleft, Δright)
            _, Δ, grad_tensor = site_backs[i-1](Δ)
            accum_param(__context__, op.tensors[i][configs[i]], grad_tensor)
            grads[i] = grad_tensor
        end
        # what remains of Δ is the gradient of the first site's slice
        accum_param(__context__, op.tensors[1][configs[1]], Δ)
        grads[1] = Δ
        return (; tensors = Tuple(grads)), nothing
    end

    return y, mps_back
end
# Rebuild the pullback now that the hand-written MPS adjoint is defined, and
# measure it. `back` is interpolated for the same reason as in the first
# benchmark: it is a non-const global.
_, back = _pullback(cx, mps, S)
@benchmark $back($(rand(10)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment