Last active
August 29, 2015 14:05
-
-
Save ggggggggg/46a7e57ccff05d3a1558 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base: isempty, copy, eltype, sizehint, push!, union!, pop!, delete!, ==, <, <=, length, start, next, done, | |
setdiff, symdiff, symdiff!, setdiff!, show, first, last, intersect, intersect!, symdiff, symdiff! | |
using Base.Test | |
# Set of non-negative integers stored as a bit vector.
type IntSetBitVector
    # Membership bits: the integer n is stored when b[n+1] is true.
    b::BitVector
    # Value assumed for every position at or beyond length(b); newly grown
    # storage is filled with it.
    fill1s::Bool

    # A new set is empty and has storage for the integers 0:255.
    IntSetBitVector() = new(falses(256), false)
end
# Build a set holding every element produced by a single iterable.
function IntSetBitVector(itr)
    set = IntSetBitVector()
    for elem in itr
        push!(set, elem)
    end
    set
end

# Build a set from elements passed as separate arguments.
function IntSetBitVector(itr...)
    set = IntSetBitVector()
    for elem in itr
        push!(set, elem)
    end
    set
end
# A "similar" set is simply a fresh empty one.
function similar(s::IntSetBitVector)
    IntSetBitVector()
end

# Copy by merging `s` into a new empty set.
function copy(s::IntSetBitVector)
    union!(IntSetBitVector(), s)
end

# Elements are reported as Int64.
function eltype(s::IntSetBitVector)
    Int64
end
# Grow the backing bit vector so that it spans at least `top` bits, rounded up
# to a multiple of 64. Never shrinks. Newly added bits take the value of
# `s.fill1s`. Returns `s`.
function sizehint(s::IntSetBitVector, top::Integer)
    current = length(s.b)
    # Nothing to do when the request already fits.
    top < current && return s
    target = iceil(top/64)*64
    if target > current
        assert(target%64==0) # this is assumed to be true elsewhere
        resize!(s.b, target)
        s.b[current+1:end] = s.fill1s
    end
    s
end
# Insert the non-negative integer `n` into `s` and return `s`.
# Throws ArgumentError for negative `n`.
function push!(s::IntSetBitVector, n::Integer)
    n < 0 && throw(ArgumentError("IntSetBitVector elements cannot be negative"))
    if n >= length(s.b)
        # Positions beyond the stored bits already count as members when
        # fill1s is set, so there is nothing to record.
        s.fill1s && return s
        # Grow by 1.5x the requested element to amortise repeated pushes.
        sizehint(s, int(n + div(n,2)))
    end
    Base.unsafe_bitsetindex!(s.b.chunks, true, n+1)
    s
end
# Add every element of the iterable `ns` to `s`, one push! at a time.
function union!(s::IntSetBitVector, ns)
    for elem in ns; push!(s, elem); end
    s
end
# Remove `n` from `s`. Returns `n` when it was a member, otherwise `default`.
function pop!(s::IntSetBitVector, n::Integer, default)
    # Negative integers can never be members; without this guard the
    # s.b[n+1] lookup below would index out of bounds.
    n < 0 && return default
    if n >= length(s.b)
        # Beyond the stored bits membership is decided by fill1s alone.
        s.fill1s || return default
        # Materialise the implicit 1s so that this single bit can be cleared.
        sizehint(s, int(n + div(n,2)))
    end
    if s.b[n+1]
        s.b[n+1] = false
        return n
    end
    default
end
# Remove `n` from `s` and return it; throws KeyError(n) when `n` is absent.
function pop!(s::IntSetBitVector, n::Integer)
    # n+1 can never be returned for a successful removal of n, so it serves
    # as the "not found" marker.
    sentinel = n+1
    pop!(s, n, sentinel) == sentinel && throw(KeyError(n))
    n
end
# Remove and return the largest element of `s`.
# TODO: what should happen when fill1s == true?
function pop!(s::IntSetBitVector)
    pop!(s, last(s))
end
# Remove `n` from `s` if present (no error when absent) and return `s`.
delete!(s::IntSetBitVector, n::Integer) = (pop!(s, n, n); s)
# Remove every element of the iterable `ns` from `s`; returns `s`.
function setdiff!(s::IntSetBitVector, ns)
    for elem in ns; delete!(s, elem); end
    s
end

# Non-mutating set difference: elements of `a` that are not in `b`.
function setdiff(a::IntSetBitVector, b::IntSetBitVector)
    setdiff!(copy(a), b)
end
# Non-mutating symmetric difference. The set with the longer backing vector is
# the one that gets copied, so the in-place symdiff! does not have to grow it.
# (The original compared length(s1.b) with itself, so the second branch was
# unreachable.)
symdiff(s1::IntSetBitVector, s2::IntSetBitVector) =
    (length(s1.b) >= length(s2.b) ? symdiff!(copy(s1), s2) : symdiff!(copy(s2), s1))
# Remove every element from `s`, including the implicit tail, and return `s`.
function empty!(s::IntSetBitVector)
    # Clear the stored bits of this set (the original referenced an
    # undefined variable `b` here).
    fill!(s.b, false)
    s.fill1s = false
    s
end
# Toggle membership of the single integer `n` in `s` and return `s`.
# Throws ArgumentError for negative `n`.
function symdiff!(s::IntSetBitVector, n::Integer)
    if n >= length(s.b)
        # Grow by 1.5x, as push! does (the original called a nonexistent
        # `dim` instead of `div`).
        lim = int(n + div(n,2))
        sizehint(s, lim)
    elseif n < 0
        throw(ArgumentError("IntSetBitVector elements cannot be negative"))
    end
    s.b[n+1] = !s.b[n+1]
    s
end
# Toggle membership of every element of the iterable `ns`; returns `s`.
function symdiff!(s::IntSetBitVector, ns)
    for elem in ns; symdiff!(s, elem); end
    s
end
# Replace the contents of `to` with those of `from`: clear `to`, then merge
# `from` into it. empty! hands back its argument, so the two steps chain.
copy!(to::IntSetBitVector, from::IntSetBitVector) = union!(empty!(to), from)
# Membership test for the integer `n`.
function in(n::Integer, s::IntSetBitVector)
    # Negative integers are never members.
    n < 0 && return false
    # Inside the stored range the bit decides.
    n < length(s.b) && return s.b[n+1]
    # max IntSetBitVector length is typemax(Int), so highest possible element
    # is typemax(Int)-1
    s.fill1s && n < typemax(Int)
end
# Iteration state is the smallest integer not yet visited; begin at 0.
function start(s::IntSetBitVector)
    int64(0)
end

# Iteration ends once the state reaches typemax(Int), or, for a set without
# an implicit tail, once no stored bit remains at or after the state.
function done(s::IntSetBitVector, i)
    i == typemax(Int) && return true
    s.fill1s && return false
    next(s,i)[1] >= length(s.b)
end
# Return (element, new state), where element is the smallest member >= i
# within the stored bits. When no stored bit remains the element is
# length(s.b); at or past the stored range the element is `i` itself.
function next(s::IntSetBitVector, i)
    if i < length(s.b)
        idx = findnext(s.b, i+1)
        # findnext yields a 1-based bit index, or 0 when nothing was found.
        n = idx == 0 ? length(s.b) : idx-1
    else
        n = int64(i)
    end
    (n, n+1)
end
# A set is empty when it has no implicit tail and none of its stored bits is
# set. (isempty(s.b) only tests whether the bit vector has zero length, which
# never holds for the 256-bit minimum storage, so the original always
# returned false.)
isempty(s::IntSetBitVector) = !s.fill1s && !any(s.b)
# Return the smallest element of `s`; errors when the set is empty.
function first(s::IntSetBitVector)
    n = next(s,0)[1]
    # next reports length(s.b) when no stored bit is set. That value is only a
    # real member when the implicit tail (fill1s) is present; otherwise the
    # set is empty.
    if n >= length(s.b) && !s.fill1s
        error("set must be non-empty")
    end
    n
end
# Remove and return the smallest element of `s`.
function shift!(s::IntSetBitVector)
    pop!(s, first(s))
end
# Return the largest element of `s`. Errors when the set is empty or when
# fill1s is set.
function last(s::IntSetBitVector)
    s.fill1s && error("set has no last element")
    chunks = s.b.chunks
    # Walk the 64-bit chunks from the top; the first non-zero chunk holds
    # the answer.
    for i = length(chunks):-1:1
        chunks[i] == 0 && continue
        for j = i*64:-1:(i-1)*64+1
            s.b[j] && (return j-1)
        end
        error("this shouldn't be possible")
    end
    error("set has no last element")
end
# Number of elements: the stored 1 bits, plus every position from length(s.b)
# up to typemax(Int)-1 when fill1s is set.
function length(s::IntSetBitVector)
    tail = s.fill1s ? typemax(Int) - length(s.b) : 0
    sum(s.b) + tail
end
# Print the set as "IntSetBitVector([a, b, ...])". When fill1s is set the
# listing stops just past the stored bits and is followed by an ellipsis up
# to typemax(Int)-1.
function show(io::IO, s::IntSetBitVector)
    print(io, "IntSetBitVector([")
    needsep = false
    for n in s
        # Stop listing once past the stored range (relevant with fill1s).
        n > length(s.b) && break
        needsep && print(io, ", ")
        print(io, n)
        needsep = true
    end
    if s.fill1s
        print(io, ", ..., ", typemax(Int)-1, ")")
    else
        print(io, "])")
    end
end
# Math functions | |
# In-place union of two bit-vector sets, done one 64-bit chunk at a time.
# Returns `s`.
function union!(s::IntSetBitVector, s2::IntSetBitVector)
    lim = length(s2.b)
    # Make sure `s` spans at least as many bits as `s2`.
    lim > length(s.b) && sizehint(s, lim)
    for i = 1:div(lim, 64)
        s.b.chunks[i] = s.b.chunks[i] | s2.b.chunks[i]
    end
    # Everything past s2's stored bits is in s2 when its fill1s is set.
    if s2.fill1s
        s.b[lim+1:end] = true
    end
    s.fill1s = s.fill1s | s2.fill1s
    s
end
# Union of a single set is a copy of it.
function union(s1::IntSetBitVector)
    copy(s1)
end

# Union of two sets: copy the one with the longer backing vector and merge
# the other into it.
function union(s1::IntSetBitVector, s2::IntSetBitVector)
    if length(s1.b) >= length(s2.b)
        union!(copy(s1), s2)
    else
        union!(copy(s2), s1)
    end
end

# Union of several sets, folded from the right.
function union(s1::IntSetBitVector, ss::IntSetBitVector...)
    union(s1, union(ss...))
end
# In-place intersection of two bit-vector sets, done one 64-bit chunk at a
# time. Returns `s`.
function intersect!(s::IntSetBitVector, s2::IntSetBitVector)
    lim = length(s2.b)
    # Make sure `s` spans at least as many bits as `s2`.
    lim > length(s.b) && sizehint(s, lim)
    for i = 1:div(lim, 64)
        s.b.chunks[i] = s.b.chunks[i] & s2.b.chunks[i]
    end
    # Nothing past s2's stored bits is in s2 unless its fill1s is set.
    if !s2.fill1s
        s.b[lim+1:end] = false
    end
    s.fill1s = s.fill1s & s2.fill1s
    s
end
# Intersection of a single set is a copy of it.
function intersect(s1::IntSetBitVector)
    copy(s1)
end

# Intersection of two sets: copy the one with the longer backing vector and
# intersect the other into it.
function intersect(s1::IntSetBitVector, s2::IntSetBitVector)
    if length(s1.b) >= length(s2.b)
        intersect!(copy(s1), s2)
    else
        intersect!(copy(s2), s1)
    end
end

# Intersection of several sets, folded from the right.
function intersect(s1::IntSetBitVector, ss::IntSetBitVector...)
    intersect(s1, intersect(ss...))
end
# Complement in place: flip the implicit tail flag and every stored bit.
# Returns `s`.
function complement!(s::IntSetBitVector)
    s.fill1s = !s.fill1s
    s.b = !s.b
    s
end

# Non-mutating complement.
function complement(s::IntSetBitVector)
    complement!(copy(s))
end
# In-place symmetric difference of two bit-vector sets, done one 64-bit chunk
# at a time. Returns `s`.
function symdiff!(s::IntSetBitVector, s2::IntSetBitVector)
    # Make sure `s` spans at least as many bits as `s2`.
    length(s2.b) > length(s.b) && sizehint(s, length(s2.b))
    for i = 1:div(length(s2.b),64)
        s.b.chunks[i] $= s2.b.chunks[i]
    end
    if s2.fill1s
        # Past s2's stored bits every position is in s2, so XOR amounts to
        # flipping each of s's remaining chunks. Each chunk must be inverted
        # from its own value (the original read chunk 1 for all of them).
        for i = length(s2.b.chunks)+1:length(s.b.chunks)
            s.b.chunks[i] = ~s.b.chunks[i]
        end
    end
    s.fill1s $= s2.fill1s
    s
end
# Two sets are equal when their implicit tails agree, their shared chunks
# match, and the longer one's extra bits all equal the shared fill value.
function ==(s1::IntSetBitVector, s2::IntSetBitVector)
    s1.fill1s == s2.fill1s || return false
    n1 = length(s1.b)
    n2 = length(s2.b)
    for i = 1:div(min(n1,n2), 64)
        s1.b.chunks[i] != s2.b.chunks[i] && return false
    end
    # In the longer vector, look for any bit past the shorter one's range that
    # disagrees with the shorter set's fill value.
    if n1 > n2
        return findnext(s1.b, !s2.fill1s, n2+1) == 0
    elseif n2 > n1
        return findnext(s2.b, !s1.fill1s, n1+1) == 0
    end
    true
end
# `a` is a subset of `b` when intersecting with `b` leaves `a` unchanged.
function issubset(a::IntSetBitVector, b::IntSetBitVector)
    isequal(a, intersect(a,b))
end

# Strict subset.
function <(a::IntSetBitVector, b::IntSetBitVector)
    (a<=b) && !isequal(a,b)
end

# Subset or equal.
function <=(a::IntSetBitVector, b::IntSetBitVector)
    issubset(a, b)
end
# Non-mutating setdiff / symdiff on IntSetBitVector.
@test setdiff(IntSetBitVector(1, 2, 3, 4), IntSetBitVector(2, 4, 5, 6)) == IntSetBitVector(1, 3)
@test symdiff(IntSetBitVector(1, 2, 3, 4), IntSetBitVector(2, 4, 5, 6)) == IntSetBitVector(1, 3, 5, 6)
@test symdiff(IntSetBitVector(2, 4, 5, 6), IntSetBitVector(1, 2, 3, 4)) == IntSetBitVector(1, 3, 5, 6)

# The same operations on Base's Set, for comparison.
@test setdiff(Set(1, 2, 3, 4), Set(2, 4, 5, 6)) == Set(1, 3)
@test symdiff(Set(1, 2, 3, 4), Set(2, 4, 5, 6)) == Set(1, 3, 5, 6)

# Mutating setdiff! on a Set ...
plainset = Set(1, 2, 3, 4)
setdiff!(plainset, Set(2, 4, 5, 6))
@test plainset == Set(1, 3)

# ... and on an IntSetBitVector.
bitset = IntSetBitVector(1, 2, 3, 4)
setdiff!(bitset, IntSetBitVector(2, 4, 5, 6))
@test bitset == IntSetBitVector(1, 3)
# issue #7851: negative elements are rejected on insertion and never members
@test_throws ArgumentError IntSetBitVector(-1)
@test !(-1 in IntSetBitVector(0:10))

## IntSetBitVector
expected = [0,1,10,20,200,300,1000,10000,10002]
set = IntSetBitVector(expected...)

# Iteration yields the elements in ascending order.
@test [elem for elem in set] == expected
@test last(set) == 10002
@test first(set) == 0
@test length(set) == 9

# Removing from either end shrinks the set by one each time.
@test pop!(set) == 10002
@test length(set) == 8
@test shift!(set) == 0
@test length(set) == 7
@test !in(0,set)
@test !in(10002,set)
@test in(10000,set)

# first/last on an empty set raise an error.
@test_throws ErrorException first(IntSetBitVector())
@test_throws ErrorException last(IntSetBitVector())

grown = copy(set)
sizehint(grown, 20000) #check that hash does not depend on size of internal Array{Uint32, 1}
# Union for Base's IntSet, used by the benchmark below: copy the set with the
# larger limit and merge the other into it.
function union(s1::IntSet, s2::IntSet)
    if s1.limit >= s2.limit
        union!(copy(s1), s2)
    else
        union!(copy(s2), s1)
    end
end
# Timing comparison: the same operations on IntSetBitVector and on Base's
# IntSet, using two random 1000-element samples from 1:10000000.
bv1 = IntSetBitVector(rand(1:10000000,1000));
bv2 = IntSetBitVector(rand(1:10000000,1000));
println("IntSetBitVector")
@time [bv1==bv2 for j = 1:1000];
@time [union(bv1,bv2) for j=1:100];
@time [for i in bv1 end for j=1:1000];
@time [intersect(bv1,bv2) for j=1:100];
@time [symdiff(bv1,bv2) for j = 1:100];

# Same contents, stored in Base's IntSet.
is1 = IntSet(collect(bv1));
is2 = IntSet(collect(bv2));
println("IntSet")
@time [is1==is2 for j = 1:1000];
@time [union(is1,is2) for j=1:100];
@time [for i in is1 end for j=1:1000];
@time [intersect(is1,is2) for j=1:100];
@time [symdiff(is1,is2) for j = 1:100];
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment