Last active
September 6, 2024 08:31
-
-
Save vankesteren/6c141f7cabcd3eb47292d78cfca1804d to your computer and use it in GitHub Desktop.
Permuting data to induce complex bivariate relations.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using StatsBase: sample, mean, cor | |
using LinearAlgebra: norm | |
using Plots, Random | |
"""
    permutefun!(x, y, rule, score; tol = 1e-3, max_iter = 10_000, max_search = 100,
                verbose = true, rng = Random.default_rng())

Permute the values of `y` in place so that `sum(rule.(x, y))` approaches `score`.

Greedy local search over transpositions: each iteration draws a random index `i`,
evaluates the effect on `sum(rule.(x, y))` of swapping `y[i]` with every `y[j]`, and
performs the best such swap if it strictly reduces the loss
`abs(score - sum(rule.(x, y)))`. Only the order of `y` changes, so the values of `y`
(and hence its marginal distribution) are preserved; `x` is never modified.

# Arguments
- `x::AbstractVector`: The vector of x values (not modified).
- `y::AbstractVector`: The vector of y values; permuted in place.
- `rule::Function`: Function taking two numbers and returning a single number (a `Bool`
  works too and counts as 0/1).
- `score::Real`: The target score, i.e., the target value of `sum(rule.(x, y))`.

# Keywords
- `tol::Number = 1e-3`: The tolerance. If the loss `abs(current_score - score)` is below
  this value, stop the algorithm (no swap is made if it already holds on entry).
- `max_iter::Integer = 10_000`: Maximum number of iterations. For large datasets, this
  may need to be increased.
- `max_search::Real = 100`: Stop after this many consecutive iterations without an
  improving swap (`Inf` disables this criterion).
- `verbose::Bool = true`: Whether to print progress information.
- `rng::AbstractRNG = Random.default_rng()`: Random number generator used to draw the
  index `i` in each iteration.

# Returns
`nothing`; the result is the permuted `y`.

# Throws
- `ArgumentError` if `x` and `y` have different lengths or are not 1-based.
"""
function permutefun!(x::AbstractVector, y::AbstractVector, rule::Function, score::Real;
                     tol::Number=1e-3, max_iter::Integer=10_000, max_search::Real=100,
                     verbose::Bool=true, rng::AbstractRNG=Random.default_rng())
    N = length(x)
    if N != length(y)
        throw(ArgumentError("x and y should be the same length!"))
    end
    # Indices are drawn from 1:N below, so offset-indexed arrays are rejected.
    Base.require_one_based_indexing(x, y)
    # With fewer than two elements no swap can change anything (and drawing an index
    # from the empty range 1:0 would throw).
    N < 2 && return nothing

    # current objective value
    current_rule = rule.(x, y)
    current_score = sum(current_rule)
    current_loss = abs(score - current_score)
    # Target already met: leave y untouched.
    if current_loss < tol
        verbose && println("\nAchieved tolerance!")
        return nothing
    end

    iter = 0
    search_iter = 0   # consecutive iterations without an improving swap
    while iter < max_iter
        # get random index
        i = rand(rng, 1:N)
        # Change in score for swapping y[i] with y[j], for every j at once: only the
        # terms at positions i and j change. For j == i the change is zero.
        delta_score = rule.(x, y[i]) .+ rule.(x[i], y) .- current_rule .- current_rule[i]
        # best swap partner j and the loss it would give
        new_loss, j = findmin(abs.(score .- (current_score .+ delta_score)))
        if new_loss < current_loss
            # Found an improving swap: apply it and refresh the two cached rule values.
            y[i], y[j] = y[j], y[i]
            current_rule[i], current_rule[j] = rule(x[i], y[i]), rule(x[j], y[j])
            # Recompute the total from the cache (O(N), same cost as delta_score above).
            current_score = sum(current_rule)
            current_loss = abs(score - current_score)
            if verbose
                println("Iter $iter | loss $current_loss | $i ↔ $j | score $current_score | search $search_iter")
            end
            search_iter = 0
        else
            search_iter += 1
        end
        # stopping conditions
        if search_iter >= max_search
            if verbose
                println("\nNo improvement found in the last $search_iter iterations.")
            end
            break
        end
        if current_loss < tol
            if verbose
                println("\nAchieved tolerance!")
            end
            break
        end
        iter += 1
    end
    # Both `break`s happen before `iter += 1`, so iter reaches max_iter only when the
    # iteration budget itself ran out.
    if verbose && iter >= max_iter
        println("\nReached the maximum number of iterations ($max_iter).")
    end
    return nothing
end
"""
    permutefun!(x, y, rule; kwargs...)

Permute `y` in place with the target score `length(x)`, i.e. call
`permutefun!(x, y, rule, length(x); kwargs...)`. For a `Bool`-valued `rule` this asks for
`rule(x[k], y[k])` to hold for every `k`.

All keyword arguments (`tol`, `max_iter`, `max_search`, `verbose`, ...) are forwarded
unchanged, so their defaults and accepted types are exactly those of the four-argument
method rather than a separately maintained copy.
"""
function permutefun!(x::Vector, y::Vector, rule::Function; kwargs...)
    return permutefun!(x, y, rule, length(x); kwargs...)
end
"""
    marginplot(x, y, title)

Build a scatter plot of `y` against `x` with a marginal histogram of `x` above it and a
marginal histogram of `y` to its right, titled `title`, and return the plot object.

NOTE: this calls `default(...)`, which changes Plots' global defaults (grey fill and
markers, no legend) for every plot created afterwards, not only this one.
"""
function marginplot(x, y, title)
    # Top row: x histogram plus an empty cell; bottom row: the scatter (80% of width
    # and height) next to the y histogram. Subplots are numbered a = 1, b = 2, c = 3.
    grid = @layout [
        a _
        b{0.8w,0.8h} c
    ]
    xrange = extrema(x)
    yrange = extrema(y)
    default(fillcolor=:grey, markercolor=:grey, legend=false)
    fig = plot(; layout=grid, link=:none, size=(500, 500), margin=-10Plots.px,
               plot_title=title)
    # main panel
    scatter!(fig, x, y; subplot=2, xlim=xrange, ylim=yrange)
    # top margin: vertical bars sharing the scatter's x-range
    histogram!(fig, x; nbins=30, subplot=1, orientation=:v, framestyle=:none,
               bottommargin=-20Plots.px, xlim=xrange)
    # right margin: horizontal bars sharing the scatter's y-range
    histogram!(fig, y; nbins=30, subplot=3, orientation=:h, framestyle=:none,
               leftmargin=-40Plots.px, ylim=yrange)
    return fig
end
# Generate some data: x is uniform on [-0.5, 0.5); y is a mixture of two normals.
N = 300
Random.seed!(45)
x = rand(N) .- 0.5
# Split N into two group sizes that always sum to N (`Int(N / 2)` throws for odd N).
n_low = N ÷ 2
n_high = N - n_low
y = vcat(randn(n_low) ./ 6 .- 0.25, randn(n_high) ./ 8 .+ 0.35)
p1 = marginplot(x, y, "Original data")

# Enforce complex constraint (default target: the rule holds for every pair)
x1, y1 = copy(x), copy(y)
permutefun!(x1, y1, (xi, yi) -> (xi^4 < yi^2))
p2 = marginplot(x1, y1, "x⁴ < y²")

# Induce a certain correlation: permuting y leaves the means and norms unchanged, so
# targeting sum((x - xm) .* (y - ym)) = 0.7 * |x - xm| * |y - ym| targets cor = 0.7.
x2, y2 = copy(x), copy(y)
xm, ym = mean(x), mean(y)
permutefun!(x2, y2, (xi, yi) -> (xi - xm) * (yi - ym), 0.7 * norm(x .- xm) * norm(y .- ym))
# A bare expression is discarded when run as a script, so report it explicitly.
println("cor(x2, y2) = ", cor(x2, y2))
p3 = marginplot(x2, y2, "Correlation = .7")

# Make a hole in the data (every point outside radius 0.3)
x3, y3 = copy(x), copy(y)
permutefun!(x3, y3, (xi, yi) -> sqrt(xi^2 + yi^2) > 0.3)
p4 = marginplot(x3, y3, "Hole of radius 0.3")

fig = plot(p1, p2, p3, p4, size=(2000, 500), layout=(1, 4), margin=10Plots.px)
savefig(fig, "permutefun.pdf")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note that the marginal distributions stay the same while the constraints are enforced, because the algorithm only permutes the y values; it never changes them.