Created July 19, 2018 04:03
Save rdeits/713c3e1eddde8cdc4fee028b2bc0a98c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if VERSION >= v"0.7-alpha" | |
using Statistics | |
using Test | |
using Random | |
else | |
using Base.Test | |
end | |
using BenchmarkTools | |
using BenchmarkTools: prettytime, time, prettymemory, memory | |
import MathOptInterface | |
const MOI = MathOptInterface | |
"""
    my_canonical(f::MOI.ScalarAffineFunction)

Return a canonicalized copy of `f`: a new function with the same constant whose
terms are sorted by variable index, with duplicate variables merged by summing
their coefficients (the work is done by `my_canonical!` on the copy). The
argument `f` is left unmodified.
"""
function my_canonical(f::MOI.ScalarAffineFunction)
    result = MOI.ScalarAffineFunction(copy(f.terms), f.constant)
    # Canonicalize the copy, not the argument. Passing `f` here would mutate the
    # caller's function in place and silently discard `result`.
    return my_canonical!(result)
end
"""
    my_canonical!(f::MOI.ScalarAffineFunction)

Canonicalize `f` in place and return it.

When `f` has more than one term, its terms are sorted by variable index using
the (unstable) `QuickSort`. Each run of terms that share a variable is then
merged into a single term whose coefficient is the sum of the run, and
`f.terms` is shrunk to the merged length. Merged coefficients that happen to be
zero are kept. A function with zero or one term is returned untouched.
"""
function my_canonical!(f::MOI.ScalarAffineFunction)
    terms = f.terms
    length(terms) > 1 || return f
    sort!(terms, by = t -> t.variable_index.value, alg = QuickSort)
    # terms[1:kept] is the merged prefix; `scan` walks the rest of the sorted terms.
    kept = 1
    for scan in 2:length(terms)
        current = terms[kept]
        candidate = terms[scan]
        if current.variable_index.value == candidate.variable_index.value
            # Same variable as the last kept term: fold its coefficient in.
            terms[kept] = MOI.ScalarAffineTerm(current.coefficient + candidate.coefficient,
                                               current.variable_index)
        else
            # New variable: it becomes the next slot of the merged prefix.
            kept += 1
            terms[kept] = candidate
        end
    end
    resize!(terms, kept)
    return f
end
"""
    DeduplicationCache{T}()

Reusable scratch space for `deduplicate_with_cache!`: a dictionary from each
`MOI.VariableIndex` to its accumulated coefficient of type `T`. It starts empty.
"""
struct DeduplicationCache{T}
    coefficients::Dict{MOI.VariableIndex, T}

    function DeduplicationCache{T}() where {T}
        return new{T}(Dict{MOI.VariableIndex, T}())
    end
end
"""
    deduplicate_with_cache!(f, cache)

Merge duplicate terms of `f` in place and return `f`.

`cache.coefficients` is emptied and used to sum the coefficients of each
variable; `f.terms` is then rebuilt from it, one term per distinct variable.
Terms come out in the dictionary's iteration order, not sorted by variable.
"""
function deduplicate_with_cache!(f::MOI.ScalarAffineFunction{T}, cache::DeduplicationCache) where {T}
    totals = cache.coefficients
    empty!(totals)
    for term in f.terms
        index = term.variable_index
        totals[index] = term.coefficient + get(totals, index, zero(valtype(totals)))
    end
    empty!(f.terms)
    for entry in totals
        push!(f.terms, MOI.ScalarAffineTerm(entry.second, entry.first))
    end
    return f
end
"""
    deduplicate_with_cache(f, cache)

Non-mutating counterpart of `deduplicate_with_cache!`. It copies `f`'s terms
(keeping the same constant), deduplicates the copy using `cache` as scratch
space, and returns the copy.
"""
deduplicate_with_cache(f::MOI.ScalarAffineFunction, cache::DeduplicationCache) =
    deduplicate_with_cache!(MOI.ScalarAffineFunction(copy(f.terms), f.constant), cache)
# Single-argument wrappers for testing and benchmarking. Each closes over its own
# `DeduplicationCache{Float64}`, allocated once here and reused on every call, so
# both can be called like the other one-argument methods.
deduplicate_with_static_cache = let shared_cache = DeduplicationCache{Float64}()
    function (f)
        return deduplicate_with_cache(f, shared_cache)
    end
end
deduplicate_with_static_cache! = let shared_cache = DeduplicationCache{Float64}()
    function (f)
        return deduplicate_with_cache!(f, shared_cache)
    end
end
"""
    random_affine_function(n)

Build a random `MOI.ScalarAffineFunction` with `n` terms using the global RNG.

Each term's variable index is drawn uniformly from `1:n`, so repeated variables
are likely. Coefficients and the constant are drawn with `randn`.
"""
function random_affine_function(n)
    variables = MOI.VariableIndex.([rand(1:n) for _ in 1:n])
    coefs = randn(n)
    terms = MOI.ScalarAffineTerm.(coefs, variables)
    return MOI.ScalarAffineFunction(terms, randn())
end
# Check every implementation (mutating and non-mutating) against MOI's reference
# `MOI.Utilities.canonical` on many random functions.
@testset "correctness" begin
    # Fresh copy so the in-place variants never modify the shared input `f`.
    copy_function(g) = MOI.ScalarAffineFunction(copy(g.terms), g.constant)
    # The cache-based methods emit terms in dictionary order; sort before comparing.
    by_variable(terms) = sort(terms, by = t -> t.variable_index.value)
    # Each method may sum a variable's duplicate terms in a different order
    # (`my_canonical!` uses the unstable QuickSort), which can change the last
    # bits of a coefficient. So variables are compared exactly and coefficients
    # approximately; exact `==` on floats would fail spuriously.
    function same_terms(a, b)
        length(a) == length(b) || return false
        for (x, y) in zip(a, b)
            x.variable_index == y.variable_index || return false
            isapprox(x.coefficient, y.coefficient, atol = 1e-10) || return false
        end
        return true
    end
    for trial in 1:1000
        f = random_affine_function(rand(1:1000))
        @inferred(MOI.Utilities.canonical(f))
        @inferred(my_canonical(f))
        @inferred(deduplicate_with_static_cache(f))
        expected = MOI.Utilities.canonical(f).terms
        @test same_terms(expected, my_canonical(f).terms)
        @test same_terms(expected, my_canonical!(copy_function(f)).terms)
        @test same_terms(expected, by_variable(deduplicate_with_static_cache(f).terms))
        @test same_terms(expected,
                         by_variable(deduplicate_with_static_cache!(copy_function(f)).terms))
    end
end
# Methods to benchmark, each paired with the column label used in the tables.
methods = [
    MOI.Utilities.canonical => "MOI.Utilities.canonical",
    my_canonical => "my_canonical",
    my_canonical! => "my_canonical!",
    deduplicate_with_static_cache => "deduplicate_with_static_cache",
    deduplicate_with_static_cache! => "deduplicate_with_static_cache!"
]
# Number of terms in the random affine functions for each benchmark size.
sizes = [10, 100, 1000, 10_000, 100_000, 1_000_000, 10_000_000]

"""
    seed_global_rng!(seed)

Reseed the global RNG on any supported Julia version. `srand` was renamed to
`Random.seed!` during the 0.7 cycle and removed in 1.0, so call whichever
spelling this Julia provides.
"""
function seed_global_rng!(seed)
    if VERSION >= v"0.7.0-beta"
        Random.seed!(seed)
    else
        srand(seed)
    end
end

# results[i, j] is the trial for methods[i] on functions with sizes[j] terms.
# Each benchmark starts from the same RNG state. `setup` runs once per sample and
# `evals = 1`, so each sample (including the mutating `!` methods) gets a freshly
# generated input.
results = [
    (seed_global_rng!(1);
     @benchmark(($method)(f), setup = (f = random_affine_function($n)), evals = 1))
    for (method, name) in methods, n in sizes
]
"""
    print_markdown_table(methods, sizes, results, format_cell)

Print a Markdown table to standard output, with one row per problem size and
one column per method.

- `methods` holds `function => label` pairs.
- `results[i, j]` must be the benchmark trial for `methods[i]` at `sizes[j]`.
- Each cell shows `format_cell(results[i, j])`.
"""
function print_markdown_table(methods, sizes, results, format_cell)
    # Header row with one column per method label.
    print("| n ")
    for pair in methods
        print("| ", pair.second, " ")
    end
    println("|")
    # Markdown separator row.
    print("| --- ")
    for _ in methods
        print("| --- ")
    end
    println("|")
    # One row per size. Rows end with "|" like the header rows.
    for (j, n) in enumerate(sizes)
        print("| ", n, " ")
        for i in 1:length(methods)
            print("| ", format_cell(results[i, j]), " ")
        end
        println("|")
    end
end

# Median run time of each method at each size.
print_markdown_table(methods, sizes, results, r -> prettytime(time(median(r))))
println()
# Median memory allocated by each method at each size.
print_markdown_table(methods, sizes, results, r -> prettymemory(memory(median(r))))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.