Skip to content

Instantly share code, notes, and snippets.

@mfansler
Created April 18, 2023 04:34
Show Gist options
  • Save mfansler/6607eb64f1e704524b6bad7ab1fb79da to your computer and use it in GitHub Desktop.
Save mfansler/6607eb64f1e704524b6bad7ab1fb79da to your computer and use it in GitHub Desktop.
## see https://stackoverflow.com/q/76039888/570918
#' Per-cell adjustment matrix from row and column margin discrepancies
#'
#' Entry (i, j) is the mean of the i-th row discrepancy and the j-th
#' column discrepancy.
#'
#' @param drow_margin numeric vector, one discrepancy per row.
#' @param dcol_margin numeric vector, one discrepancy per column.
#' @param freeze if TRUE, zero every cell whose row or column discrepancy
#'   is already 0, so satisfied rows and columns are not adjusted.
#' @return A `length(drow_margin)` x `length(dcol_margin)` numeric matrix.
diffmat <- function(drow_margin, dcol_margin, freeze=TRUE) {
  n_r <- length(drow_margin)
  n_c <- length(dcol_margin)
  ## row_part[i, j] = drow_margin[i]; col_part[i, j] = dcol_margin[j]
  row_part <- matrix(drow_margin, nrow=n_r, ncol=n_c)
  col_part <- matrix(dcol_margin, nrow=n_r, ncol=n_c, byrow=TRUE)
  adjustment <- (row_part + col_part) / 2
  ## 0/1 mask that keeps satisfactory rows and columns fixed; a plain 1
  ## leaves every cell as is when freezing is off
  mask <- if (freeze) outer(drow_margin != 0, dcol_margin != 0) else 1
  adjustment * mask
}
#' Generate random boolean matrix from row and column margins
#'
#' Stochastic search: start from independent Bernoulli cell probabilities,
#' draw a 0/1 matrix, then repeatedly shift each cell's probability on the
#' logit scale against the current row/column discrepancies (via
#' `diffmat()`) and redraw, until a draw matches both margins or
#' `max_trials` updates have been spent.
#'
#' @param row_margin target row sums; its length sets the number of rows.
#' @param col_margin target column sums; its length sets the number of
#'   columns.
#' @param max_trials maximum number of probability updates after the
#'   initial draw.
#' @param rate step size multiplying `diffmat()` on the logit scale.
#' @return A list with `mat` (logical matrix from the last draw), `pmat`
#'   (the cell probabilities it was drawn from), `converged` (TRUE when
#'   `mat` matches both margins) and `n_steps` (updates performed).
rmatseek <- function(row_margin, col_margin, max_trials=1000, rate=1) {
  ## Equal totals are necessary for any 0/1 matrix to match both margins;
  ## without them the loop below can only run out of trials.
  if (sum(row_margin) != sum(col_margin)) {
    warning("row and column margins have different totals; ",
            "the search cannot converge", call.=FALSE)
  }
  ## One Bernoulli draw per cell, keeping the dim of `p`. Same form as the
  ## deprecated purrr::rbernoulli(), so a given seed yields the same draws
  ## without attaching purrr.
  draw <- function(p) stats::runif(length(p)) > (1 - p)
  n_steps <- 0
  pmat <- outer(row_margin, col_margin)
  n_entries <- length(pmat)
  ## Scaling by the cell count (rather than the grand total) keeps every
  ## product <= 1 as long as margins do not exceed the matrix dimensions.
  pmat <- pmat/n_entries
  ## Cells forced by the margins: a full row or full column must be all 1s,
  ## an empty row or empty column must be all 0s (zeros win on conflict).
  pmat[outer(row_margin == length(col_margin),
             col_margin == length(row_margin), `|`)] <- 1
  pmat[outer(row_margin == 0, col_margin == 0, `|`)] <- 0
  mat <- draw(pmat)
  drow <- rowSums(mat) - row_margin
  dcol <- colSums(mat) - col_margin
  while (n_steps < max_trials && sum(abs(drow), abs(dcol)) > 0) {
    n_steps <- n_steps + 1
    ## Rows/columns with too many 1s (positive discrepancy) lower their
    ## cells' probabilities, too few raise them. Forced cells sit at
    ## qlogis(0) = -Inf / qlogis(1) = Inf and therefore stay put.
    pmat <- stats::plogis(stats::qlogis(pmat) - rate*diffmat(drow, dcol))
    mat <- draw(pmat)
    drow <- rowSums(mat) - row_margin
    dcol <- colSums(mat) - col_margin
  }
  list(mat=mat,
       pmat=pmat,
       converged=sum(abs(drow), abs(dcol)) == 0,
       n_steps=n_steps)
}
## Example
## Results are printed explicitly: inside an `if` block only the last
## value would otherwise be shown.
if (interactive()) {
  library(magrittr)
  set.seed(20230417)
  ## StackOverflow Question
  ROW_SUMS <- c(4, 2, 3, 5, 3)
  COL_SUMS <- c(5, 1, 5, 2, 4)
  print(rmatseek(ROW_SUMS, COL_SUMS))
  ## 10x10 permutation matrices
  print(rmatseek(rep(1, 10), rep(1, 10)))
  ## 20x20 half fill
  replicate(100, rmatseek(rep(10, 20), rep(10, 20), rate=0.5)$n_steps) %>%
    summary() %>%
    print()
  ## larger sparse: Bernoulli(0.02) cells, drawn in the same form as
  ## purrr::rbernoulli(1e4, p=0.02) so the seeded stream is unchanged
  my_rmat <- (runif(1e4) > (1 - 0.02)) %>% matrix(100, 100)
  print(rmatseek(rowSums(my_rmat), colSums(my_rmat)))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment