Performance regression from Tracker to Zygote: an MPS case
Gist Roger-luo/f2bfe56d882c06e9905fad8e4e1cf826, created February 18, 2020 20:40
using TNFilters
using Flux
using Zygote
using BenchmarkTools
using TNFilters: bmm, bmm!, batched_tr
using Flux: params
using Zygote: AContext, Context, _pullback, cache, accum_param
# I have not found out why yet, but the manually written adjoint gives roughly the same performance as Tracker (Tracker takes about 35μs).
# random MPS test case and a random configuration S of 36 values from {1, 2}
mps = MPS{10}(ntuple(_->2, 36+1)) do l, r, B
    rand(l, r, B)
end
S = rand(1:2, 36)
ps = params(mps)
# register every parameter in the context cache so that accum_param accumulates gradients into it
cx = Context()
for p in ps
    cache(cx)[p] = nothing
end
getproperty(mps, :tensors)
# field access on an MPS returns no gradient; parameter gradients are collected through the cache instead
Zygote.@adjoint getproperty(mps::MPS, f) = getproperty(mps, f), Δ -> nothing
Zygote.@adjoint getfield(mps::MPS, f) = getfield(mps, f), Δ -> nothing
# Zygote-generated adjoint: the backward pass takes about 500μs
_, back = _pullback(cx, mps, S)
@benchmark back($(rand(10)))
# Manually written adjoint: the backward pass takes about 45μs
Zygote.@adjoint function (op::MPS)(configs::AbstractVector)
    @boundscheck length(configs) == length(op) || error("length of the configuration does not match mps length")
    # forward pass: contract the selected tensors with batched matmul, recording one pullback per step
    output = @inbounds op.tensors[1][configs[1]]
    stack = Vector{Any}(undef, length(op) - 1)
    @inbounds for i in 2:length(op)
        # named `bmm_back` so it does not clash with the closure `back` defined below
        output, bmm_back = _pullback(__context__, bmm, output, op.tensors[i][configs[i]])
        stack[i-1] = bmm_back
    end
    y, tr_back = _pullback(batched_tr, output)
    function back(Δ)
        _, Δ = tr_back(Δ)
        grads = []
        # replay the recorded pullbacks from the last site back to the first
        for i in length(op):-1:2
            _, Δ, grad_tensor = stack[i-1](Δ)
            accum_param(__context__, op.tensors[i][configs[i]], grad_tensor)
            push!(grads, grad_tensor)
        end
        accum_param(__context__, op.tensors[1][configs[1]], Δ)
        push!(grads, Δ)
        # grads was filled in reverse site order; reverse it so it lines up with op.tensors
        return (;tensors=Tuple(reverse(grads))), nothing
    end
    y, back
end
_, back = _pullback(cx, mps, S)
@benchmark back($(rand(10)))
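The manual adjoint above follows a simple pattern: the forward pass records one pullback per `bmm` step in a stack, and the backward pass replays that stack in reverse, starting from the pullback of the final trace. The sketch below reproduces the pattern without Zygote or TNFilters so it can run on its own. `my_bmm`, `my_batched_tr`, their pullbacks, and `chain_tr_with_grads` are hypothetical stand-ins written for illustration, assuming arrays laid out as (rows, cols, batch); they are not TNFilters' actual implementations, and the configuration indexing of the MPS is left out.

```julia
using LinearAlgebra

# Hypothetical stand-in for a batched matmul: C[:, :, k] = A[:, :, k] * B[:, :, k]
function my_bmm(A::Array{T,3}, B::Array{T,3}) where T
    C = Array{T,3}(undef, size(A, 1), size(B, 2), size(A, 3))
    for k in axes(A, 3)
        @views mul!(C[:, :, k], A[:, :, k], B[:, :, k])
    end
    return C
end

# Pullback of C = A * B per batch slice: ΔA = ΔC * B', ΔB = A' * ΔC
function my_bmm_pullback(A, B)
    C = my_bmm(A, B)
    back(ΔC) = (my_bmm(ΔC, permutedims(B, (2, 1, 3))),
                my_bmm(permutedims(A, (2, 1, 3)), ΔC))
    return C, back
end

# Hypothetical stand-in for a batched trace: y[k] = tr(A[:, :, k])
my_batched_tr(A) = [tr(@view A[:, :, k]) for k in axes(A, 3)]

# Pullback of the batched trace: ΔA[:, :, k] = Δ[k] * I
function my_batched_tr_pullback(A)
    y = my_batched_tr(A)
    function back(Δ)
        ΔA = zero(A)
        for k in axes(A, 3), i in 1:min(size(A, 1), size(A, 2))
            ΔA[i, i, k] = Δ[k]
        end
        return ΔA
    end
    return y, back
end

# Forward: record one pullback per product. Backward: replay them in reverse order.
function chain_tr_with_grads(tensors::Vector{Array{Float64,3}}, Δ::Vector{Float64})
    output = tensors[1]
    stack = Vector{Any}(undef, length(tensors) - 1)
    for i in 2:length(tensors)
        output, bmm_back = my_bmm_pullback(output, tensors[i])
        stack[i-1] = bmm_back
    end
    y, tr_back = my_batched_tr_pullback(output)
    Δout = tr_back(Δ)
    # gradients are stored by site index, so they line up with `tensors`
    grads = Vector{Array{Float64,3}}(undef, length(tensors))
    for i in length(tensors):-1:2
        Δout, grads[i] = stack[i-1](Δout)
    end
    grads[1] = Δout
    return y, grads
end
```

Storing gradients by site index, rather than pushing them onto a list during the reverse sweep, avoids having to reverse the result afterwards. Because each tensor enters the product exactly once, the output is linear in every individual entry, which makes a central finite-difference check a convenient sanity test for this kind of hand-written adjoint.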