Created
May 3, 2020 23:43
-
-
Save jlmelville/bebf1de0af7fe87f15f94f0e1d852e7a to your computer and use it in GitHub Desktop.
useful functions for running smallvis on multiple datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run `fun` once per dataset and collect the results in a list keyed by
# dataset name.
#
# datasets: named list of datasets; each element is passed to `fun` as `X`.
# fun:      function invoked via do.call() with the assembled arguments.
# vfun:     optional function(varargs, name = ) returning a named list of
#           extra/overriding arguments for the current dataset.
# mfun:     optional function(method, varargs, name = ) returning a named list
#           of entries to set on the `method` argument (a character `method`
#           is wrapped in a list first).
# ...:      arguments forwarded to `fun`. `ret_extra` is set to TRUE when it
#           is not supplied. Any list-valued argument whose names equal the
#           dataset names (same order) is treated as a per-dataset lookup,
#           e.g. perplexity = list(iris = 50, s1k = 333, ...).
#
# Returns a list with one entry per dataset name.
benchmark <- function(datasets = smallvis_datasets(), fun = smallvis::smallvis,
                      vfun = NULL, mfun = NULL, ...) {
  base_args <- list(...)
  if (is.null(base_args$ret_extra)) {
    base_args$ret_extra <- TRUE
  }

  all_names <- names(datasets)
  results <- list()

  for (ds_name in all_names) {
    message(ds_name)

    # start every dataset from the same pristine argument list
    call_args <- base_args
    call_args$X <- datasets[[ds_name]]

    # data-dependent arguments supplied by the caller's hook
    if (!is.null(vfun)) {
      extra_args <- vfun(call_args, name = ds_name)
      for (key in names(extra_args)) {
        call_args[[key]] <- extra_args[[key]]
      }
    }

    # data-dependent additions to the method specification
    if (!is.null(mfun)) {
      method_spec <- call_args$method
      if (is.character(method_spec)) {
        method_spec <- list(method_spec)
      }
      method_extras <- mfun(method_spec, call_args, name = ds_name)
      for (key in names(method_extras)) {
        method_spec[[key]] <- method_extras[[key]]
      }
      call_args$method <- method_spec
    }

    # collapse per-dataset lookups down to the value for this dataset
    for (arg_name in names(call_args)) {
      if (!is.list(call_args[[arg_name]])) {
        next
      }
      keys <- names(call_args[[arg_name]])
      if (length(keys) != length(all_names) || !all(keys == all_names)) {
        next
      }
      call_args[[arg_name]] <- call_args[[arg_name]][[ds_name]]
    }

    results[[ds_name]] <- do.call(fun, call_args)
  }

  results
}
# Collect the benchmark datasets into a list keyed by dataset name.
# Every dataset object (iris, s1k, oli, frey, coil20, mnist6k, fashion6k)
# is looked up by name when this function is called, so all of them must be
# visible from the calling session.
smallvis_datasets <- function() {
  datasets <- list(iris, s1k, oli, frey, coil20, mnist6k, fashion6k)
  names(datasets) <- c(
    "iris", "s1k", "oli", "frey", "coil20", "mnist6k", "fashion6k"
  )
  datasets
}
# Plot every result in `bench_res` by calling smallvis_plot() once per entry.
#
# bench_res:    named list of results; each name is also used with get() to
#               fetch the dataset object of the same name.
# info:         text appended after the dataset name in the `info` argument
#               passed to smallvis_plot().
# perplexity:   when not NULL, forwarded as `perplexity`.
# export_dir:   when not NULL, a `filename` of the form
#               <export_dir>/<name>_<code>.png is forwarded.
# code:         suffix used in the exported file name.
# include_cost: forwarded to smallvis_plot() unchanged.
# ...:          further plot arguments; values given here take precedence
#               over the per-dataset entries from viz_defaults().
plotall <- function(bench_res, info = "", perplexity = NULL,
                    export_dir = NULL, code = "", include_cost = FALSE,
                    ...) {
  default_table <- viz_defaults()

  for (ds_name in names(bench_res)) {
    dataset <- get(ds_name)

    # caller-supplied arguments win; defaults only fill the gaps
    plot_args <- list(...)
    ds_defaults <- default_table[[ds_name]]
    for (key in names(ds_defaults)) {
      if (is.null(plot_args[[key]])) {
        plot_args[[key]] <- ds_defaults[[key]]
      }
    }

    plot_args$X <- dataset

    # a bare (non-list) result is wrapped so it is reachable as res$Y
    ds_result <- bench_res[[ds_name]]
    if (is.list(ds_result)) {
      plot_args$res <- ds_result
    } else {
      plot_args$res <- list()
      plot_args$res$Y <- ds_result
    }

    plot_args$include_cost <- include_cost
    plot_args$info <- paste0(ds_name, " ", info)

    if (!is.null(export_dir)) {
      png_name <- paste0(ds_name, "_", code, ".png")
      plot_args$filename <- paste0(export_dir, "/", png_name)
    }
    if (!is.null(perplexity)) {
      plot_args$perplexity <- perplexity
    }

    do.call(smallvis_plot, plot_args)
  }
}
# Plot a single embedding with vizier::embed_plot() and optionally save the
# current plot to a PNG file.
#
# X:            passed to vizier::embed_plot() as `x`.
# res:          result list; res$Y is used as the plot coordinates and
#               res$itercosts is read when include_cost is TRUE.
# perplexity:   defaults to res$perplexity; accepted but not used in the body.
# info:         title text, or a function of `res` that returns the title
#               text, e.g.
#               info = function(res) { paste0("opt-SNE lr = ", res$opt$eta) }
# include_cost: if TRUE, append " cost = <final itercosts value>" to the title.
# filename:     when not NULL, the plot is copied to this file via save_image().
# ...:          further arguments forwarded to vizier::embed_plot().
smallvis_plot <- function(X, res, perplexity = res$perplexity,
                          info = "", include_cost = FALSE,
                          filename = NULL,
                          ...) {
  if (is.function(info)) {
    infofn <- info
    info <- infofn(res)
  }
  message(info)

  title <- info
  if (include_cost) {
    # utils::tail() is used rather than a bare last(), which is not part of
    # base R and is not defined in this file
    final_cost <- utils::tail(res$itercosts, 1)
    title <- paste0(title, " cost = ", formatC(final_cost))
  }

  vizargs <- list(...)
  vizargs$title <- title
  vizargs$coords <- res$Y
  vizargs$x <- X
  do.call(vizier::embed_plot, vizargs)

  if (!is.null(filename)) {
    save_image(filename)
  }
}
# Copy the current graphics device to a PNG device writing to `filename`,
# then close the device that is current after the copy.
# Returns the value of grDevices::dev.off().
save_image <- function(filename) {
  png_device <- grDevices::png
  grDevices::dev.copy(device = png_device, filename = filename)
  grDevices::dev.off()
}
# Per-dataset default plot settings: a list keyed by dataset name, where each
# entry is a list holding `cex` and `alpha_scale`.
viz_defaults <- function() {
  point_style <- function(cex, alpha_scale) {
    list(cex = cex, alpha_scale = alpha_scale)
  }

  list(
    iris = point_style(1, 0.5),
    s1k = point_style(1, 0.5),
    oli = point_style(1, 0.5),
    frey = point_style(1, 0.25),
    coil20 = point_style(0.75, 0.25),
    mnist6k = point_style(1, 0.25),
    fashion6k = point_style(1, 0.25)
  )
}
# Emit, via message(), a markdown section per dataset containing a two-column
# table of image links.
#
# imgdir:     image directory name; links are written as
#             ../<imgdir>/<dataset>_<code>.png
# codes:      vector of run codes; one image cell is written per code, two
#             cells per table row.
# datasets:   dataset names to write sections for. NOTE(review): the default
#             calls small_data_names(), which is not defined in this file, so
#             it must be available in the calling session or `datasets` must
#             be given explicitly.
# head_level: number of "#" characters used for each section heading.
img_tables <- function(imgdir, codes, datasets = small_data_names(), head_level = 3) {
  img_prefix <- paste0("../", imgdir, "/")
  header <- paste(replicate(head_level, "#"), collapse = "")
  for (data in datasets) {
    message(header, " ", data)
    message()
    message("| | |")
    message(":----------------------------:|:--------------------------:")
    # seq_along() so that an empty `codes` writes no cells; 1:length(codes)
    # would iterate over c(1, 0) and emit a bogus NA image link
    for (i in seq_along(codes)) {
      message("![", data, " ", codes[i], "](", img_prefix, data, "_", codes[i], ".png)", appendLF = FALSE)
      if (i %% 2 == 0) {
        # second cell of the pair: end the table row
        message()
      } else {
        # first cell of the pair: write the column separator
        message("|", appendLF = FALSE)
      }
    }
    message()
  }
  message()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment