Created
June 10, 2021 17:11
-
-
Save tbenst/f193cafe7855744e1d50ce56228dbbe2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import IterTools, Cairo | |
using Colors, Compose, Fontconfig, PyCall, StatsBase, Glob, | |
FileIO, Measures, Format, Unitful | |
pickle = pyimport("pickle")
##
@__DIR__
## https://github.com/julia-vscode/julia-vscode/issues/2104
# Model output folders live next to this script. (An older layout kept them in
# joinpath(@__DIR__, "2020-reconstructions/final-outputs").)
base_dir = joinpath(@__DIR__, "final-outputs")
model_folders = filter(isdir, joinpath.(base_dir, readdir(base_dir)))
model_names = basename.(model_folders)
# number of test images shown per model in each figure row
NUM_IMAGES_PER_MODEL = 10
# preferred display order; names without a matching folder are skipped with a warning
order = ["originals", "linear_model", "64x64-mlp",
    "64x64-mlp-small", "64x64-resnet-mlp", "64x64-convnet", "perceptual_model"]
found = filter(name -> name in model_names, order)
absent = setdiff(order, found)
isempty(absent) || @warn "Model folders not found in $base_dir; skipping them" absent
# exact-name lookup (searchsortedfirst would silently return a neighbouring folder)
idxs = Int[findfirst(==(name), model_names) for name in found]
model_names = model_names[idxs]
model_folders = model_folders[idxs]

# load perceptual losses: a pickled mapping model name => per-image losses
# (presumably one value per test image — confirm against the producing script).
# NOTE(review): pickle.load can execute arbitrary code; only load this locally
# generated file, never untrusted input.
percept_losses = let f = pybuiltin("open")(joinpath(@__DIR__, "each_image_pl.pickle"), "rb")
    try
        pickle.load(f)
    finally
        f.close()   # release the Python file handle even if unpickling fails
    end
end
avg_percept_losses = Dict(k => mean(values(v)) for (k, v) in percept_losses)
println("perceptual model ranking")
# (model => average perceptual loss) pairs, lowest loss first
percept_ranking = sort(collect(avg_percept_losses); by=last)
@show percept_ranking
## load MSE
"""
    se(x, y)

Return the summed squared error between `x` and `y`: the element-wise (broadcast)
difference `x .- y`, with every residual squared and then added up.
"""
function se(x, y)
    residual = x .- y
    return sum(r -> r^2, residual)
end
"""
    calc_MSE(model_name; original="originals",
             data_folder=joinpath(@__DIR__, "final-outputs"))

Compare every `test*.png` image in `data_folder/model_name` with the identically
named image in `data_folder/original`. Return the average, over those images, of the
per-image summed squared pixel error (`se`). Images are converted with
`Float64.(load(...))` before comparison.

Note: this is a per-image *summed* squared error averaged across images, not a
per-pixel mean.

Throws an `ArgumentError` if the model folder contains no `test*.png` images, rather
than returning `NaN`.
"""
function calc_MSE(model_name; original="originals",
                  data_folder=joinpath(@__DIR__, "final-outputs"))
    model_path = joinpath(data_folder, model_name)
    orig_path = joinpath(data_folder, original)
    img_names = basename.(glob("test*.png", model_path))
    isempty(img_names) &&
        throw(ArgumentError("no test*.png images found in $model_path"))
    cum_se = 0.0
    for name in img_names
        x = Float64.(load(joinpath(model_path, name)))
        y = Float64.(load(joinpath(orig_path, name)))
        cum_se += se(x, y)
    end
    return cum_se / length(img_names)
end
# Hard-coded per-model MSE scores (presumably produced by `calc_MSE`). To recompute
# them from the images instead, run:
#   avg_mse = Dict(k => calc_MSE(k) for k in keys(percept_losses))
#   order = sortperm(map(x -> x[2], collect(avg_mse)))
#   println("MSE model ranking")
#   @show collect(avg_mse)[order]
# Alternative values previously listed (commented out): 16x16-convnet 55.3,
# 32x32-convnet 48.3, 64x64-convnet 38.0.
avg_mse = Dict{String,Float64}(
    # convnets trained on increasing channel sampling (Figure 1)
    "8x8-convnet" => 79.3,
    "16x16-convnet" => 52.8,
    "32x32-convnet" => 38.2,
    "64x64-convnet" => 36.6,
    # architecture comparison at 64x64 (Figure 2)
    "linear_model" => 42.3,
    "64x64-mlp" => 41.3,
    "64x64-mlp-small" => 38.4,
    "64x64-resnet-mlp" => 33.0,
    # scrambling controls (supplementary figure)
    "32x32-convnet-descrambling" => 10.3,
    "32x32-convnet-scrambled-targets" => 72.6,
    "32x32-convnet-scrambled-both" => 88.8,
)
##
"""
    _mea_site_positions(N, everyN) -> (grays, reds)

Split the sites `(i, j)` of an `N × N` grid (1-based; `i` outer, `j` inner) into
unsampled (`grays`) and sampled (`reds`) sites. A site is sampled when both `i - 1`
and `j - 1` are multiples of `everyN`. Each result is a `2 × k` matrix whose columns
are the positions `[i, j] ./ (N + 1)`. `grays` is `2 × 0` when every site is sampled.
"""
function _mea_site_positions(N, everyN)
    onlattice(k) = (k - 1) % everyN == 0
    issampled(i, j) = onlattice(i) && onlattice(j)
    sampled = [(i, j) for i in 1:N for j in 1:N if issampled(i, j)]
    unsampled = [(i, j) for i in 1:N for j in 1:N if !issampled(i, j)]
    # 2 × k matrix: row 1 holds the i coordinates, row 2 the j coordinates
    tocoords(sites) = [site[k] / (N + 1) for k in 1:2, site in sites]
    return tocoords(unsampled), tocoords(sampled)
end

"""
    drawMEA(N, everyN)

Draw an `N × N` electrode grid as a Compose context. Sites on every `everyN`-th
row and column (starting from the first) are filled dark red; all other sites are
filled dark gray. Each site is a square of side `1 / (1.5N)`.
"""
function drawMEA(N, everyN)
    grays, reds = _mea_site_positions(N, everyN)
    sz = 1 / (1.5 * N)
    red_layer = (context(), rectangle(reds[1, :], reds[2, :], [sz], [sz]), fill("darkred"))
    if isempty(grays)
        return compose(context(), red_layer)
    end
    gray_layer = (context(), rectangle(grays[1, :], grays[2, :], [sz], [sz]), fill("darkgray"))
    return compose(context(), gray_layer, red_layer)
end
"""
    centered_text(the_text, fs=7pt)

Return a Compose context holding `the_text`, anchored at the context's midpoint
`(0.5, 0.5)` and centered both horizontally and vertically, at font size `fs`.
"""
function centered_text(the_text, fs=7pt)
    label = text(0.5, 0.5, the_text, hcenter, vcenter)
    return compose(context(), label, fontsize(fs))
end
"""
    get_images_for_model(model_folder)

Return the raw, undecoded bytes of the test images in `model_folder`. Only files
whose names start with `test<digit>-` (i.e. `test0-…` through `test9-…`) are used,
in `readdir`'s sorted order. Callers embed these bytes directly as PNG bitmaps.

Throws an error if fewer than `NUM_IMAGES_PER_MODEL` matching files are found,
because callers index the first `NUM_IMAGES_PER_MODEL` images.
"""
function get_images_for_model(model_folder)
    # the single-digit pattern below can select at most indices 0-9, i.e. 10 images
    @assert NUM_IMAGES_PER_MODEL == 10
    # anchored at the start so names like "mytest1-x.png" are not picked up
    img_names = filter(name -> occursin(r"^test[0-9]-", name), readdir(model_folder))
    length(img_names) >= NUM_IMAGES_PER_MODEL || error(
        "expected at least $NUM_IMAGES_PER_MODEL test images in $model_folder, " *
        "found $(length(img_names))")
    return read.(joinpath.(model_folder, img_names))
end
## FIGURE 1 MEA SIZE
# Table layout, one row per model (row 1 = original images):
#   col 1                  MEA overview graphic and label (rows 3-4)
#   col 2                  8x8 channel-sampling pattern of the row's model
#   im_cols                the first NUM_IMAGES_PER_MODEL test images
#   mse_col, percept_col   score columns placed after the images
# ncol is derived from the image columns so the score columns never overwrite an image.
im_start_col = 3
im_cols = im_start_col:(im_start_col + NUM_IMAGES_PER_MODEL - 1)
mse_col = last(im_cols) + 1
percept_col = mse_col + 1
ncol = percept_col
nrow = 5
W = 183mm
H = W / ncol * nrow
model_names = ["$(n)x$(n)-convnet" for n in [8, 16, 32, 64]]
tab = table(nrow, ncol, 1:nrow, 1:ncol)
tab[1, 2] = [centered_text("channel\nsampling\n per 8x8")]
# rows 2..nrow: sampling pattern + reconstructions, from every 8th channel per
# row/column (8x8-convnet) to every channel (64x64-convnet)
for (row, everyN, mn) in zip(2:nrow, [8, 4, 2, 1], model_names)
    tab[row, 2] = [compose(context(), drawMEA(8, everyN))]
    model_folder = joinpath(@__DIR__, "final-outputs", mn)
    images = get_images_for_model(model_folder)
    for (idx, col) in enumerate(im_cols)
        tab[row, col] = [compose(context(), bitmap("image/png", images[idx], 0, 0, 1, 1))]
    end
end
# row 1: original images
model_folder = joinpath(@__DIR__, "final-outputs", "originals")
images = get_images_for_model(model_folder)
for (idx, col) in enumerate(im_cols)
    tab[1, col] = [compose(context(), bitmap("image/png", images[idx], 0, 0, 1, 1))]
end
# score columns: headers in row 1, one value per model row
tab[1, mse_col] = [centered_text("MSE")]
tab[1, percept_col] = [centered_text("Percept")]
for (row, mn) in zip(2:nrow, model_names)
    tab[row, mse_col] = [centered_text("$(round(avg_mse[mn], digits=2))")]
    tab[row, percept_col] = [centered_text("$(round(avg_percept_losses[mn], digits=2))")]
end
# MEA cutout: circular MEA outline with a gray electrode area; a red square marks
# the zoomed 8x8 region, with dashed lines leading to the right edge
circ_lw = 3pt
mea_rect = 0.6
zoom_rect = mea_rect / 8
mea = compose(context(),
    # zoomed selection
    (context(), rectangle(0.5 - zoom_rect / 2, 0.5 - zoom_rect / 2, zoom_rect, zoom_rect),
        stroke("red"), fill(nothing)),
    (context(), line([(1, 0), (0.5 + zoom_rect / 2, 0.5 - zoom_rect / 2)]),
        strokedash([1.2mm, 1.2mm]), stroke("red")),
    (context(), line([(1, 1), (0.5 + zoom_rect / 2, 0.5 + zoom_rect / 2)]),
        strokedash([1.2mm, 1.2mm]), stroke("red")),
    # MEA border
    (context(), circle(0.5cx, 0.5cy, (1cx - circ_lw) / 2),
        fill(nothing), stroke("black"), linewidth(circ_lw)),
    (context(), rectangle(0.5 - mea_rect / 2, 0.5 - mea_rect / 2, mea_rect, mea_rect),
        fill("gray80")),
)
tab[3, 1] = [mea]
tab[4, 1] = [compose(context(),
    text(0.5, 0.5, "64x64\nHD-MEA", hcenter),
    fontsize(7pt))]
set_default_graphic_size(W, H)
fig = compose(context(), tab)
fn = "figure1_channel_comparison"
fig |> SVG(joinpath(@__DIR__, "$fn.svg"))
@show H, W
# fraction of a 247 mm page height (presumably the usable journal page height)
println("$(round(H / 247mm, digits=2)) of a page")
# PNG rasterization resolution in dots per inch
dpi = 200
mm2px = x -> Int(round(uconvert(u"inch", Quantity(x.value, u"mm")) * dpi / 1u"inch", digits=0))
px_w = mm2px(W)
px_h = mm2px(H)
# the SVG is written to @__DIR__, so run this command from that directory
cmd = "inkscape -w $px_w -h $px_h $fn.svg --export-filename $fn.png"
println("to make PNG: $cmd")
fig
## FIGURE 2 model architecture
base_dir = joinpath(@__DIR__, "final-outputs")
model_folders = filter(isdir, joinpath.(base_dir, readdir(base_dir)))
model_names = basename.(model_folders)
NUM_IMAGES_PER_MODEL = 10
# one table row per entry, top to bottom; pretty_names are the matching row labels
order = ["originals", "linear_model", "64x64-mlp",
    "64x64-resnet-mlp", "64x64-convnet"]
pretty_names = ["", "linear", "MLP",
    "resMLP", "resUNet"]
# exact-name lookup: fail loudly instead of silently using a neighbouring folder
idxs = map(order) do name
    idx = findfirst(==(name), model_names)
    idx === nothing && error("no model folder named \"$name\" in $base_dir")
    idx
end
model_names = model_names[idxs]
model_folders = model_folders[idxs]
# Table layout: col 1 row label, then the image columns, then MSE and perceptual
# scores. ncol is derived so the score columns never overwrite the last image.
im_start_col = 2
im_cols = im_start_col:(im_start_col + NUM_IMAGES_PER_MODEL - 1)
mse_col = last(im_cols) + 1
percept_col = mse_col + 1
ncol = percept_col
nrow = length(order)
W = 183mm
H = W / ncol * nrow
tab = table(nrow, ncol, 1:nrow, 1:ncol)
# model names
for (row, label) in enumerate(pretty_names)
    tab[row, 1] = [centered_text(label)]
end
# render images
for (row, model_folder) in zip(1:nrow, model_folders)
    images = get_images_for_model(model_folder)
    for (idx, col) in enumerate(im_cols)
        tab[row, col] = [compose(context(), bitmap("image/png", images[idx], 0, 0, 1, 1))]
    end
end
# score columns: headers in row 1 (originals), one value per model row
tab[1, mse_col] = [centered_text("MSE")]
tab[1, percept_col] = [centered_text("Percept")]
for (row, mn) in zip(2:nrow, model_names[2:end])
    tab[row, mse_col] = [centered_text("$(round(avg_mse[mn], digits=2))")]
    tab[row, percept_col] = [centered_text("$(round(avg_percept_losses[mn], digits=2))")]
end
set_default_graphic_size(W, H)
fig = compose(context(), tab)
fn = "figure2_model_architecture"
fig |> SVG(joinpath(@__DIR__, "$fn.svg"))
@show H, W
# fraction of a 247 mm page height (presumably the usable journal page height)
println("$(round(H / 247mm, digits=2)) of a page")
# PNG rasterization resolution in dots per inch
dpi = 200
mm2px = x -> Int(round(uconvert(u"inch", Quantity(x.value, u"mm")) * dpi / 1u"inch", digits=0))
px_w = mm2px(W)
px_h = mm2px(H)
# the SVG is written to @__DIR__, so run this command from that directory
cmd = "inkscape -w $px_w -h $px_h $fn.svg --export-filename $fn.png"
println("to make PNG: $cmd")
fig
## supp figure scramble
# NOTE(review): absolute, machine-specific path. The other figures read
# joinpath(@__DIR__, "final-outputs"); switch to that once the scrambling outputs
# are available there.
# base_dir = joinpath(@__DIR__, "final-outputs")
base_dir = "/mnt/dropbox/Dropbox/Science/manuscripts/2019_acuity_paper/2020-reconstructions/final-outputs"
model_folders = filter(isdir, joinpath.(base_dir, readdir(base_dir)))
model_names = basename.(model_folders)
NUM_IMAGES_PER_MODEL = 10
# one table row per entry, top to bottom; pretty_names are the matching row labels
order = ["originals", "scrambled-targets", "32x32-convnet-descrambling",
    "32x32-convnet-scrambled-targets", "32x32-convnet-scrambled-both"]
pretty_names = ["", "scrambled", "de-\nscrambling",
    "targets\nscrambled", "both\nscrambled"]
# exact-name lookup: fail loudly instead of silently using a neighbouring folder
idxs = map(order) do name
    idx = findfirst(==(name), model_names)
    idx === nothing && error("no model folder named \"$name\" in $base_dir")
    idx
end
model_names = model_names[idxs]
model_folders = model_folders[idxs]
# Table layout: col 1 row label, then the image columns, then one MSE column
im_start_col = 2
im_cols = im_start_col:(im_start_col + NUM_IMAGES_PER_MODEL - 1)
mse_col = last(im_cols) + 1
ncol = mse_col
nrow = length(order)
W = 183mm
H = W / ncol * nrow
tab = table(nrow, ncol, 1:nrow, 1:ncol)
# model names
for (row, label) in enumerate(pretty_names)
    tab[row, 1] = [centered_text(label)]
end
# render images
for (row, model_folder) in zip(1:nrow, model_folders)
    images = get_images_for_model(model_folder)
    for (idx, col) in enumerate(im_cols)
        tab[row, col] = [compose(context(), bitmap("image/png", images[idx], 0, 0, 1, 1))]
    end
end
# MSE column: header in row 1. Rows 1-2 ("originals", "scrambled-targets") have no
# avg_mse entry, so scores start at row 3.
tab[1, mse_col] = [centered_text("MSE")]
for (row, mn) in zip(3:nrow, model_names[3:end])
    tab[row, mse_col] = [centered_text("$(round(avg_mse[mn], digits=2))")]
end
set_default_graphic_size(W, H)
fig = compose(context(), tab)
fn = "figureSup_scrambling"
fig |> SVG(joinpath(@__DIR__, "$fn.svg"))
@show H, W
# fraction of a 247 mm page height (presumably the usable journal page height)
println("$(round(H / 247mm, digits=2)) of a page")
# PNG rasterization resolution in dots per inch
dpi = 200
mm2px = x -> Int(round(uconvert(u"inch", Quantity(x.value, u"mm")) * dpi / 1u"inch", digits=0))
px_w = mm2px(W)
px_h = mm2px(H)
# the SVG is written to @__DIR__, so run this command from that directory
cmd = "inkscape -w $px_w -h $px_h $fn.svg --export-filename $fn.png"
println("to make PNG: $cmd")
fig
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment