Created
June 22, 2020 17:11
-
-
Save yiyuezhuo/93d632a2fc146a38bbfbed0bc64254eb to your computer and use it in GitHub Desktop.
ImageTrackingAD
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using Zygote\n", | |
"using Images\n", | |
"using VideoIO\n", | |
"using ImageTracking\n", | |
"using FiniteDiff\n", | |
"\n", | |
"using Zygote: @adjoint" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(PermutedDimsArray{RGB{Normed{UInt8,8}},2,(2, 1),(2, 1),Array{RGB{Normed{UInt8,8}},2}}, (128, 228))" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# NOTE(review): machine-specific absolute path; point this at a local video before running.\n", | |
"source = \"/home/yiyuezhuo/windows/d/datasets/stove2/SVD_dummy_4/video_sample_free/2正常/9219847_22.mp4\"\n", | |
"\n", | |
"# Read the first two frames. The finally clause closes the reader even when a read throws,\n", | |
"# so a failed read no longer leaks the open video handle.\n", | |
"r = VideoIO.openvideo(source)\n", | |
"img1, img2 = try\n", | |
"    first_frame = read(r)\n", | |
"    second_frame = read(r)\n", | |
"    first_frame, second_frame\n", | |
"finally\n", | |
"    close(r)\n", | |
"end\n", | |
"\n", | |
"# Gray{Float32} versions of both frames (these are what optical_flow is called with).\n", | |
"img1_gray = img1 .|> Gray{Float32}\n", | |
"img2_gray = img2 .|> Gray{Float32}\n", | |
"\n", | |
"img1 |> typeof, img1 |> size" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Array{Float32,3}, (3, 128, 228))" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Float32 channel arrays of both frames (channel index first, see the displayed size).\n", | |
"img1_arr = Float32.(channelview(img1))\n", | |
"img2_arr = Float32.(channelview(img2))\n", | |
"# Random weights with the same shape and element type as the image array; they are\n", | |
"# multiplied into the warped image to form a scalar objective for the gradient checks.\n", | |
"w_mat = randn(eltype(img1_arr), size(img1_arr)...)\n", | |
"\n", | |
"(typeof(w_mat), size(w_mat))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Array{StaticArrays.SArray{Tuple{2},Float64,1,2},2}, (128, 228))" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Farneback optical-flow configuration, one keyword per line.\n", | |
"algorithm = Farneback(\n", | |
"    50;\n", | |
"    estimation_window = 11,\n", | |
"    σ_estimation_window = 9.0,\n", | |
"    expansion_window = 6,\n", | |
"    σ_expansion_window = 5.0,\n", | |
")\n", | |
"\n", | |
"# Dense flow between the two gray frames: one 2-vector per pixel.\n", | |
"flow = optical_flow(img1_gray, img2_gray, algorithm)\n", | |
"\n", | |
"(typeof(flow), size(flow))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Array{Float32,3}, (2, 128, 228))" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Unpack the matrix of 2-vectors into a Float32 array indexed [component, i, j].\n", | |
"flow_arr = Float32[flow[i, j][k] for k in 1:2, i in axes(flow, 1), j in axes(flow, 2)]\n", | |
"(typeof(flow_arr), size(flow_arr))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"get_left_right_weight (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Linear-interpolation helper for one coordinate.\n", | |
"#\n", | |
"# Returns (lo, hi, lo_weight, hi_weight), where lo = floor(Int, val) and\n", | |
"# hi = ceil(Int, val) are the two neighbouring integer positions and each weight\n", | |
"# is one minus the distance from val to that neighbour.\n", | |
"# When val is already integral both neighbours coincide and the weights are (1.0, 0.0).\n", | |
"function get_left_right_weight(val)\n", | |
"    lo = floor(Int, val)\n", | |
"    hi = ceil(Int, val)\n", | |
"    # Integral coordinate: put the whole weight on the first neighbour.\n", | |
"    lo == hi && return lo, hi, 1., 0.\n", | |
"    lo_weight = 1. - val + lo\n", | |
"    hi_weight = 1. - hi + val\n", | |
"    return lo, hi, lo_weight, hi_weight\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One interpolation contribution to an output pixel, recorded by apply_flow_2 so that\n", | |
"# ∇apply_flow_1 can replay it during the backward pass.\n", | |
"struct BackTracker{T}\n", | |
" # Indices (2nd and 3rd array dimensions) of the source pixel that was read.\n", | |
" i::Int\n", | |
" j::Int\n", | |
" # Interpolation weight applied to that source pixel.\n", | |
" w::T\n", | |
" # Source pixel values scaled by the derivative of the weight with respect to\n", | |
" # flow component 1 (di) and flow component 2 (dj); one entry per channel.\n", | |
" di::Vector{T}\n", | |
" dj::Vector{T}\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"apply_flow_2 (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Warp img1_c with a per-pixel displacement field using bilinear sampling, and record\n", | |
"# for every output pixel what the backward pass needs.\n", | |
"#\n", | |
"# img1_c is indexed [channel, i, j]; flow is indexed [component, i, j], where component 1\n", | |
"# is added to i and component 2 is added to j to obtain the sampling position.\n", | |
"# Neighbours that fall outside the image get weight zero and contribute nothing.\n", | |
"#\n", | |
"# Returns (img2_t, back_tracker_mat): the warped array created with zero(img1_c), and a\n", | |
"# matrix holding, per output pixel, the list of BackTracker records that produced it.\n", | |
"function apply_flow_2(img1_c, flow)\n", | |
"    n_rows, n_cols = size(img1_c, 2), size(img1_c, 3)\n", | |
"    row_ok(r) = 1 <= r <= n_rows\n", | |
"    col_ok(c) = 1 <= c <= n_cols\n", | |
"\n", | |
"    img2_t = zero(img1_c)\n", | |
"    back_tracker_mat = [BackTracker[] for i in 1:n_rows, j in 1:n_cols]\n", | |
"\n", | |
"    for i in 1:n_rows, j in 1:n_cols\n", | |
"        bottom, top, bottom_weight, top_weight = get_left_right_weight(i + flow[1, i, j])\n", | |
"        left, right, left_weight, right_weight = get_left_right_weight(j + flow[2, i, j])\n", | |
"\n", | |
"        # Drop any neighbour whose index lies outside the image.\n", | |
"        col_ok(left) || (left_weight = 0.)\n", | |
"        row_ok(bottom) || (bottom_weight = 0.)\n", | |
"        col_ok(right) || (right_weight = 0.)\n", | |
"        row_ok(top) || (top_weight = 0.)\n", | |
"\n", | |
"        # The four corners, always in the order right-top, right-bottom, left-top, left-bottom.\n", | |
"        weights = [right_weight * top_weight, right_weight * bottom_weight,\n", | |
"                   left_weight * top_weight, left_weight * bottom_weight]\n", | |
"        sources = [(top, right), (bottom, right), (top, left), (bottom, left)]\n", | |
"        # Derivative of each corner weight with respect to flow component 1 and 2.\n", | |
"        d_first = [right_weight, -right_weight, left_weight, -left_weight]\n", | |
"        d_second = [top_weight, bottom_weight, -top_weight, -bottom_weight]\n", | |
"\n", | |
"        for k in eachindex(weights)\n", | |
"            w = weights[k]\n", | |
"            # Corners with no positive weight are neither read nor recorded.\n", | |
"            w > 0 || continue\n", | |
"            src_i, src_j = sources[k]\n", | |
"            ov = img1_c[:, src_i, src_j]\n", | |
"            img2_t[:, i, j] += w * ov\n", | |
"            bt = BackTracker(src_i, src_j, w, ov * d_first[k], ov * d_second[k])\n", | |
"            push!(back_tracker_mat[i, j], bt)\n", | |
"        end\n", | |
"    end\n", | |
"    return img2_t, back_tracker_mat\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"apply_flow_1 (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Warp img1_c by flow and return only the warped array, discarding the back-tracking\n", | |
"# records that apply_flow_2 also produces.\n", | |
"function apply_flow_1(img1_c, flow)\n", | |
"    warped, _ = apply_flow_2(img1_c, flow)\n", | |
"    return warped\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"∇apply_flow_1 (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Backward pass of apply_flow_1.\n", | |
"#\n", | |
"# grad is the gradient with respect to the warped output, indexed [channel, i, j];\n", | |
"# back_tracker_mat is the record matrix returned by apply_flow_2.\n", | |
"# Returns (grad_i, grad_flow): the gradient with respect to the input image (same shape\n", | |
"# as grad) and with respect to the flow (2 x size(grad, 2) x size(grad, 3)).\n", | |
"function ∇apply_flow_1(grad, back_tracker_mat)\n", | |
"    n_rows, n_cols = size(grad, 2), size(grad, 3)\n", | |
"    grad_i = collect(zero(grad))\n", | |
"    grad_flow = zeros(eltype(grad), 2, n_rows, n_cols)\n", | |
"\n", | |
"    for i in 1:n_rows, j in 1:n_cols\n", | |
"        # Output gradient at this pixel, one entry per channel.\n", | |
"        g = grad[:, i, j]\n", | |
"        for tracker in back_tracker_mat[i, j]\n", | |
"            # Scatter back to the source pixel with the weight used in the forward pass.\n", | |
"            grad_i[:, tracker.i, tracker.j] += g * tracker.w\n", | |
"            # Gradient with respect to the two flow components of this output pixel.\n", | |
"            grad_flow[1, i, j] += sum(g .* tracker.di)\n", | |
"            grad_flow[2, i, j] += sum(g .* tracker.dj)\n", | |
"        end\n", | |
"    end\n", | |
"    return grad_i, grad_flow\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Custom Zygote adjoint for apply_flow_1: run the forward warp once, keep its\n", | |
"# back-tracking records, and hand them to ∇apply_flow_1 in the pullback.\n", | |
"@adjoint function apply_flow_1(img1_arr, flow_arr)\n", | |
"    warped, trackers = apply_flow_2(img1_arr, flow_arr)\n", | |
"    return warped, grad -> ∇apply_flow_1(grad, trackers)\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Float32[1.9060305 -1.3470834 … -0.859872 -0.93491775; -0.05513298 -0.65744525 … -0.122666925 0.34732682; 0.47509822 -0.03135249 … -0.7193032 1.2011641]\n", | |
"\n", | |
"Float32[-0.7354938 -0.3852348 … -0.022246383 0.36421964; 1.0123888 1.1131711 … -1.2337879 -1.2176565; 0.75340176 0.10042812 … -1.2375505 -0.25818032]\n", | |
"\n", | |
"Float32[-1.0146208 0.66275 … 0.3624875 0.50923115; -0.108753026 -0.3002437 … 0.011126457 -1.07455; -1.3513801 -0.1486916 … 1.691757 1.4860505]\n", | |
"\n", | |
"...\n", | |
"\n", | |
"Float32[-0.41907698 -0.35688752 … -0.094783634 -0.8622765; -1.1904081 1.0911273 … 0.080233306 0.6655891; 1.0631945 1.119349 … 1.9679961 -0.7786205]\n", | |
"\n", | |
"Float32[-0.1960498 -1.4646168 … -0.8551477 0.6595207; -1.2361741 0.40957487 … 1.5015283 0.95523006; -0.830574 0.25330108 … 0.11482439 -0.5030325]\n", | |
"\n", | |
"Float32[-0.17120354 -1.2233598 … 0.7954573 -0.5615203; -0.6241249 0.69396955 … 0.09401224 -0.19023323; 1.0751172 1.9486365 … -0.14907522 0.9777236], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", | |
"\n", | |
"Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", | |
"\n", | |
"Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", | |
"\n", | |
"...\n", | |
"\n", | |
"Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", | |
"\n", | |
"Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", | |
"\n", | |
"Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Differentiate the scalar objective sum(warped .* w_mat) with respect to both the\n", | |
"# image array and the flow array.\n", | |
"grad_i, grad_flow = gradient(\n", | |
"    (img, flow) -> sum(apply_flow_1(img, flow) .* w_mat),\n", | |
"    img1_arr,\n", | |
"    flow_arr,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2-element Array{Float32,1}:\n", | |
" 0.084884696\n", | |
" 0.32596415" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Analytic gradient with respect to both flow components at output pixel (10, 10).\n", | |
"grad_flow[:, 10, 10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.0858306884765625" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Forward finite-difference check of the flow gradient: nudge flow component 1 at\n", | |
"# pixel (10, 10) by 1e-3 and compare the change in the objective with grad_flow[1, 10, 10].\n", | |
"flow_arr_0 = copy(flow_arr)\n", | |
"flow_arr_1 = copy(flow_arr)\n", | |
"flow_arr_1[1, 10, 10] += 1e-3\n", | |
"\n", | |
"let objective = flow -> sum(apply_flow_1(img1_arr, flow) .* w_mat)\n", | |
"    (objective(flow_arr_1) - objective(flow_arr_0)) / 1e-3\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-0.5035400390625" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Forward finite-difference check of the image gradient: nudge channel 1 of pixel\n", | |
"# (10, 10) by 1e-3 and compare the change in the objective with grad_i[1, 10, 10].\n", | |
"img1_arr_0 = copy(img1_arr)\n", | |
"img1_arr_1 = copy(img1_arr)\n", | |
"img1_arr_1[1, 10, 10] += 1e-3\n", | |
"\n", | |
"let objective = img -> sum(apply_flow_1(img, flow_arr) .* w_mat)\n", | |
"    (objective(img1_arr_1) - objective(img1_arr_0)) / 1e-3\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-0.5043076f0" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Analytic gradient with respect to channel 1 of image pixel (10, 10).\n", | |
"grad_i[1, 10, 10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 1.4.1", | |
"language": "julia", | |
"name": "julia-1.4" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.4.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment